Skip to content

Commit 8796ee9

Browse files
committed
Update DataLoader.cs
1 parent d333a18 commit 8796ee9

File tree

1 file changed

+116
-16
lines changed

1 file changed

+116
-16
lines changed

src/TorchSharp/DataLoader.cs

Lines changed: 116 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -18,24 +18,64 @@ public static partial class utils
1818
public static partial class data
1919
{
2020

21-
public static Modules.DataLoader DataLoader(Dataset dataset, int batchSize, IEnumerable<long> shuffler, Device device = null, int num_worker = 1, bool drop_last = false)
21+
public static Modules.DataLoader DataLoader(
22+
Dataset dataset,
23+
int batchSize, IEnumerable<long> shuffler,
24+
Device device = null,
25+
int num_worker = 1, bool drop_last = false,
26+
bool disposeBatch = true, bool disposeDataset = true)
2227
{
23-
return new Modules.DataLoader(dataset, batchSize,shuffler, device, num_worker, drop_last);
28+
return new Modules.DataLoader(
29+
dataset,
30+
batchSize, shuffler,
31+
device,
32+
num_worker, drop_last,
33+
disposeBatch, disposeDataset);
2434
}
2535

26-
public static Modules.DataLoader DataLoader(Dataset dataset, int batchSize, bool shuffle = false, Device device = null, int? seed = null, int num_worker = 1, bool drop_last = false)
36+
public static Modules.DataLoader DataLoader(
37+
Dataset dataset,
38+
int batchSize, bool shuffle = false,
39+
Device device = null, int? seed = null,
40+
int num_worker = 1, bool drop_last = false,
41+
bool disposeBatch = true, bool disposeDataset = true)
2742
{
28-
return new Modules.DataLoader(dataset,batchSize,shuffle, device, seed, num_worker,drop_last);
43+
return new Modules.DataLoader(
44+
dataset,
45+
batchSize, shuffle,
46+
device, seed,
47+
num_worker, drop_last,
48+
disposeBatch, disposeDataset);
2949
}
3050

31-
public static Modules.IterableDataLoader DataLoader(IterableDataset dataset, int batchSize, IEnumerable<long> shuffler, Device device = null, int num_worker = 1, bool drop_last = false)
51+
public static Modules.IterableDataLoader DataLoader(
52+
IterableDataset dataset,
53+
int batchSize, IEnumerable<long> shuffler,
54+
Device device = null,
55+
int num_worker = 1, bool drop_last = false,
56+
bool disposeBatch = true, bool disposeDataset = true)
3257
{
33-
return new Modules.IterableDataLoader(dataset, batchSize, shuffler, device, num_worker, drop_last);
58+
return new Modules.IterableDataLoader(
59+
dataset,
60+
batchSize, shuffler,
61+
device,
62+
num_worker, drop_last,
63+
disposeBatch, disposeDataset);
3464
}
3565

36-
public static Modules.IterableDataLoader DataLoader(IterableDataset dataset, int batchSize, bool shuffle = false, Device device = null, int? seed = null, int num_worker = 1, bool drop_last = false)
66+
public static Modules.IterableDataLoader DataLoader(
67+
IterableDataset dataset,
68+
int batchSize, bool shuffle = false,
69+
Device device = null, int? seed = null,
70+
int num_worker = 1, bool drop_last = false,
71+
bool disposeBatch = true, bool disposeDataset = true)
3772
{
38-
return new Modules.IterableDataLoader(dataset, batchSize, shuffle, device, seed, num_worker, drop_last);
73+
return new Modules.IterableDataLoader(
74+
dataset,
75+
batchSize, shuffle,
76+
device, seed,
77+
num_worker, drop_last,
78+
disposeBatch, disposeDataset);
3979
}
4080
}
4181
}
@@ -64,8 +104,23 @@ public class DataLoader : DataLoader<Dictionary<string, torch.Tensor>, Dictionar
64104
/// Set to true to drop the last incomplete batch, if the dataset size is not divisible by the batch size.
65105
/// If alse and the size of dataset is not divisible by the batch size, then the last batch will be smaller.
66106
/// </param>
67-
public DataLoader(Dataset dataset, int batchSize, IEnumerable<long> shuffler, Device device = null, int num_worker = 1, bool drop_last = false)
68-
: base(dataset, batchSize, Collate, shuffler, device, num_worker, drop_last)
107+
/// <param name="disposeBatch">
108+
/// Indicates whether to automatically dispose the collated tensors after an iteration.
109+
/// </param>
110+
/// <param name="disposeDataset">
111+
/// Indicates whether to dispose the dataset when being disposed.
112+
/// </param>
113+
public DataLoader(
114+
Dataset dataset,
115+
int batchSize, IEnumerable<long> shuffler,
116+
Device device = null,
117+
int num_worker = 1, bool drop_last = false,
118+
bool disposeBatch = true, bool disposeDataset = true)
119+
: base(dataset,
120+
batchSize, Collate, shuffler,
121+
device,
122+
num_worker, drop_last,
123+
disposeBatch, disposeDataset)
69124
{
70125
}
71126

@@ -82,8 +137,23 @@ public DataLoader(Dataset dataset, int batchSize, IEnumerable<long> shuffler, De
82137
/// Set to true to drop the last incomplete batch, if the dataset size is not divisible by the batch size.
83138
/// If alse and the size of dataset is not divisible by the batch size, then the last batch will be smaller.
84139
/// </param>
85-
public DataLoader(Dataset dataset, int batchSize, bool shuffle = false, Device device = null, int? seed = null, int num_worker = 1, bool drop_last = false)
86-
: base(dataset, batchSize, Collate, shuffle, device, seed, num_worker, drop_last)
140+
/// <param name="disposeBatch">
141+
/// Indicates whether to automatically dispose the collated tensors after an iteration.
142+
/// </param>
143+
/// <param name="disposeDataset">
144+
/// Indicates whether to dispose the dataset when being disposed.
145+
/// </param>
146+
public DataLoader(
147+
Dataset dataset,
148+
int batchSize, bool shuffle = false,
149+
Device device = null, int? seed = null,
150+
int num_worker = 1, bool drop_last = false,
151+
bool disposeBatch = true, bool disposeDataset = true)
152+
: base(dataset,
153+
batchSize, Collate, shuffle,
154+
device, seed,
155+
num_worker, drop_last,
156+
disposeBatch, disposeDataset)
87157
{
88158
}
89159

@@ -120,8 +190,23 @@ public class IterableDataLoader : DataLoader<IList<torch.Tensor>, IList<torch.Te
120190
/// Set to true to drop the last incomplete batch, if the dataset size is not divisible by the batch size.
121191
/// If alse and the size of dataset is not divisible by the batch size, then the last batch will be smaller.
122192
/// </param>
123-
public IterableDataLoader(IterableDataset dataset, int batchSize, IEnumerable<long> shuffler, Device device = null, int num_worker = 1, bool drop_last = false)
124-
: base(dataset, batchSize, Collate, shuffler, device, num_worker, drop_last)
193+
/// <param name="disposeBatch">
194+
/// Indicates whether to automatically dispose the collated tensors after an iteration.
195+
/// </param>
196+
/// <param name="disposeDataset">
197+
/// Indicates whether to dispose the dataset when being disposed.
198+
/// </param>
199+
public IterableDataLoader(
200+
IterableDataset dataset,
201+
int batchSize, IEnumerable<long> shuffler,
202+
Device device = null,
203+
int num_worker = 1, bool drop_last = false,
204+
bool disposeBatch = true, bool disposeDataset = true)
205+
: base(dataset,
206+
batchSize, Collate, shuffler,
207+
device,
208+
num_worker, drop_last,
209+
disposeBatch, disposeDataset)
125210
{
126211
}
127212

@@ -138,8 +223,23 @@ public IterableDataLoader(IterableDataset dataset, int batchSize, IEnumerable<lo
138223
/// Set to true to drop the last incomplete batch, if the dataset size is not divisible by the batch size.
139224
/// If alse and the size of dataset is not divisible by the batch size, then the last batch will be smaller.
140225
/// </param>
141-
public IterableDataLoader(IterableDataset dataset, int batchSize, bool shuffle = false, Device device = null, int? seed = null, int num_worker = 1, bool drop_last = false)
142-
: base(dataset, batchSize, Collate, shuffle, device, seed, num_worker, drop_last)
226+
/// <param name="disposeBatch">
227+
/// Indicates whether to automatically dispose the collated tensors after an iteration.
228+
/// </param>
229+
/// <param name="disposeDataset">
230+
/// Indicates whether to dispose the dataset when being disposed.
231+
/// </param>
232+
public IterableDataLoader(
233+
IterableDataset dataset,
234+
int batchSize, bool shuffle = false,
235+
Device device = null, int? seed = null,
236+
int num_worker = 1, bool drop_last = false,
237+
bool disposeBatch = true, bool disposeDataset = true)
238+
: base(dataset,
239+
batchSize, Collate, shuffle,
240+
device, seed,
241+
num_worker, drop_last,
242+
disposeBatch, disposeDataset)
143243
{
144244
}
145245

0 commit comments

Comments
 (0)