diff --git a/src/VirtoCommerce.ImportModule.CsvHelper/CsvDataReader.cs b/src/VirtoCommerce.ImportModule.CsvHelper/CsvDataReader.cs index 7cd17a3..5f76a80 100644 --- a/src/VirtoCommerce.ImportModule.CsvHelper/CsvDataReader.cs +++ b/src/VirtoCommerce.ImportModule.CsvHelper/CsvDataReader.cs @@ -14,14 +14,15 @@ namespace VirtoCommerce.ImportModule.CsvHelper { public class CsvDataReader : IImportDataReader where TCsvClassMap : ClassMap { - private readonly Stream _stream; - private readonly int _pageSize; private readonly bool _needReadRaw; - private int? _totalCount; - private string _headerRaw; + private bool _disposed; protected CsvConfiguration CsvConfiguration { get; set; } + protected readonly Stream Stream; + protected readonly Stream CountStream; protected readonly CsvReader CsvReader; + protected string HeaderRaw; + protected int? TotalCount; public bool HasMoreResults { get; private set; } = true; @@ -29,8 +30,8 @@ public CsvDataReader(Stream stream, ImportContext context, bool needReadRaw = fa { CsvConfiguration = GetConfiguration(context); - _stream = stream; - CsvReader = new CsvReader(new StreamReader(_stream), CsvConfiguration); + Stream = stream; + CsvReader = new CsvReader(new StreamReader(Stream), CsvConfiguration); CsvReader.Context.RegisterClassMap(); _pageSize = Convert.ToInt32(context.ImportProfile.Settings.FirstOrDefault(x => x.Name == CsvSettings.PageSize.Name)?.Value ?? 50); @@ -41,49 +42,73 @@ public CsvDataReader(Stream stream, ImportContext context, CsvConfiguration csvC { CsvConfiguration = MergeWithDefaultConfig(csvConfiguration, context); - _stream = stream; - CsvReader = new CsvReader(new StreamReader(_stream), CsvConfiguration); + Stream = stream; + CsvReader = new CsvReader(new StreamReader(Stream), CsvConfiguration); CsvReader.Context.RegisterClassMap(); _pageSize = Convert.ToInt32(context.ImportProfile.Settings.FirstOrDefault(x => x.Name == CsvSettings.PageSize.Name)?.Value ?? 50); _needReadRaw = needReadRaw; } + public CsvDataReader(Stream stream, Stream countStream, ImportContext context, bool needReadRaw = false) + : this(stream, context, needReadRaw) + { + CountStream = countStream; + } + + public CsvDataReader(Stream stream, Stream countStream, ImportContext context, CsvConfiguration csvConfiguration, bool needReadRaw = false) + : this(stream, context, csvConfiguration, needReadRaw) + { + CountStream = countStream; + } + public virtual async Task GetTotalCountAsync(ImportContext context) { - if (_totalCount.HasValue) + if (TotalCount.HasValue) { - return _totalCount.Value; + return TotalCount.Value; + } + + Stream stream; + bool leaveOpen; + if (Stream.CanSeek) + { + stream = Stream; + leaveOpen = true; + } + else + { + stream = CountStream ?? throw new InvalidOperationException("Count stream is not provided."); + leaveOpen = false; } var streamPosition = 0L; - if (_stream.CanSeek) + if (stream.CanSeek) { - streamPosition = _stream.Position; - _stream.Seek(0, SeekOrigin.Begin); + streamPosition = stream.Position; + stream.Seek(0, SeekOrigin.Begin); } - var streamReader = new StreamReader(_stream, leaveOpen: true); - var csvReader = new CsvReader(streamReader, CsvConfiguration); + using var csvReader = new CsvReader(new StreamReader(stream), CsvConfiguration, leaveOpen); await csvReader.ReadAsync(); csvReader.ReadHeader(); - _headerRaw = string.Join(csvReader.Configuration.Delimiter, csvReader.HeaderRecord); + HeaderRaw = string.Join(csvReader.Configuration.Delimiter, csvReader.HeaderRecord); - _totalCount = 0; + TotalCount = 0; while (await csvReader.ReadAsync()) { - _totalCount++; + TotalCount++; } - if (_stream.CanSeek) + if (stream.CanSeek) { - _stream.Seek(streamPosition, SeekOrigin.Begin); + stream.Seek(streamPosition, SeekOrigin.Begin); } - return _totalCount.Value; + return TotalCount.Value; } public virtual async Task ReadNextPageAsync(ImportContext context) @@ -110,7 +135,7 @@ public virtual async Task ReadNextPageAsync(ImportContext context) result.Add(new CsvImportRecord { Row = row, - RawHeader = _headerRaw, + RawHeader = HeaderRaw, RawRecord = rawRecord, Record = record, }); @@ -195,21 +220,30 @@ public void Dispose() protected virtual void Dispose(bool disposing) { - CsvReader.Dispose(); - _stream?.Dispose(); + if (_disposed) + { + return; + } + if (disposing) + { + CsvReader.Dispose(); + Stream?.Dispose(); + CountStream?.Dispose(); + } + _disposed = true; } private CsvConfiguration MergeWithDefaultConfig(CsvConfiguration csvConfiguration, ImportContext context) { var defaultCsvConfiguration = GetConfiguration(context); - var result = csvConfiguration; - result.Delimiter = result.Delimiter ?? defaultCsvConfiguration.Delimiter; - result.PrepareHeaderForMatch = result.PrepareHeaderForMatch ?? defaultCsvConfiguration.PrepareHeaderForMatch; - result.BadDataFound = result.BadDataFound ?? defaultCsvConfiguration.BadDataFound; - result.ReadingExceptionOccurred = result.ReadingExceptionOccurred ?? defaultCsvConfiguration.ReadingExceptionOccurred; - result.MissingFieldFound = result.MissingFieldFound ?? defaultCsvConfiguration.MissingFieldFound; - - return result; + + csvConfiguration.Delimiter ??= defaultCsvConfiguration.Delimiter; + csvConfiguration.PrepareHeaderForMatch ??= defaultCsvConfiguration.PrepareHeaderForMatch; + csvConfiguration.BadDataFound ??= defaultCsvConfiguration.BadDataFound; + csvConfiguration.ReadingExceptionOccurred ??= defaultCsvConfiguration.ReadingExceptionOccurred; + csvConfiguration.MissingFieldFound ??= defaultCsvConfiguration.MissingFieldFound; + + return csvConfiguration; } } }