Skip to content

Commit 9a302a1

Browse files
Merge branch 'main' into 1314
2 parents a638018 + cc229b8 commit 9a302a1

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

87 files changed

+511
-410
lines changed

src/Examples/SequenceToSequence.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -268,10 +268,10 @@ protected override void Dispose(bool disposing)
268268
base.Dispose(disposing);
269269
}
270270

271-
protected override Module _to(DeviceType deviceType, int deviceIndex = -1)
271+
protected override Module _to(DeviceType deviceType, int deviceIndex, bool non_blocking)
272272
{
273273
this.device = new Device(deviceType, deviceIndex);
274-
return base._to(deviceType, deviceIndex);
274+
return base._to(deviceType, deviceIndex, non_blocking);
275275
}
276276
}
277277

src/Native/LibTorchSharp/THSModule.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,29 +25,29 @@ void THSNN_Module_zero_grad(const NNModule module, bool set_to_none)
2525
(*module)->zero_grad(set_to_none);
2626
}
2727

28-
void THSNN_Module_to_device(NNModule module, int64_t device, int64_t index)
28+
void THSNN_Module_to_device(NNModule module, int64_t device, int64_t index, const bool non_blocking)
2929
{
3030
c10::DeviceType dev = c10::kCPU;
3131
if (device == 1)
3232
dev = c10::kCUDA;
3333
if (device == 13)
3434
dev = c10::kMPS;
35-
(*module)->to(torch::Device(dev, index));
35+
(*module)->to(torch::Device(dev, index), non_blocking);
3636
}
3737

38-
void THSNN_Module_to_dtype(NNModule module, int8_t dtype)
38+
void THSNN_Module_to_dtype(NNModule module, int8_t dtype, const bool non_blocking)
3939
{
40-
(*module)->to((at::ScalarType)dtype);
40+
(*module)->to((at::ScalarType)dtype, non_blocking);
4141
}
4242

43-
void THSNN_Module_to_device_dtype(NNModule module, int8_t dtype, int64_t device, int64_t index)
43+
void THSNN_Module_to_device_dtype(NNModule module, int8_t dtype, int64_t device, int64_t index, const bool non_blocking)
4444
{
4545
c10::DeviceType dev = c10::kCPU;
4646
if (device == 1)
4747
dev = c10::kCUDA;
4848
if (device == 13)
4949
dev = c10::kMPS;
50-
(*module)->to(torch::Device(dev, index), (at::ScalarType)dtype);
50+
(*module)->to(torch::Device(dev, index), (at::ScalarType)dtype, non_blocking);
5151
}
5252

5353
void THSNN_Module_dispose(const NNModule module)

src/Native/LibTorchSharp/THSNN.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,9 @@ EXPORT_API(void) THSNN_Module_register_buffer(const NNModule module, cons
2828
EXPORT_API(void) THSNN_Module_register_parameter(const NNModule module, const char* name, const Tensor tensor, bool requires_grad);
2929
EXPORT_API(void) THSNN_Module_register_module(const NNModule module, const char* name, const NNModule submodule);
3030
EXPORT_API(void) THSNN_Module_dispose(const NNModule module);
31-
EXPORT_API(void) THSNN_Module_to_device(NNModule module, int64_t device, int64_t index);
32-
EXPORT_API(void) THSNN_Module_to_dtype(NNModule module, int8_t dtype);
33-
EXPORT_API(void) THSNN_Module_to_device_dtype(NNModule module, int8_t dtype, int64_t device, int64_t index);
31+
EXPORT_API(void) THSNN_Module_to_device(NNModule module, int64_t device, int64_t index, const bool non_blocking);
32+
EXPORT_API(void) THSNN_Module_to_dtype(NNModule module, int8_t dtype, const bool non_blocking);
33+
EXPORT_API(void) THSNN_Module_to_device_dtype(NNModule module, int8_t dtype, int64_t device, int64_t index, const bool non_blocking);
3434

3535
EXPORT_API(void) THSNN_AnyModule_dispose(const NNAnyModule module);
3636
//EXPORT_API(NNModule) THSNN_AnyModule_get(const NNAnyModule module);

src/Native/LibTorchSharp/THSTensor.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1815,24 +1815,24 @@ void THSTensor_set_(Tensor tensor, const Tensor source)
18151815
CATCH(tensor->set_(*source););
18161816
}
18171817

1818-
Tensor THSTensor_to_device(const Tensor tensor, const int device_type, const int device_index, const bool copy)
1818+
Tensor THSTensor_to_device(const Tensor tensor, const int device_type, const int device_index, const bool copy, const bool non_blocking)
18191819
{
18201820
CATCH_RETURN_Tensor(
18211821
auto device = c10::Device((c10::DeviceType)device_type, (c10::DeviceIndex)device_index);
1822-
res = ResultTensor(tensor->to(device, false, copy));
1822+
res = ResultTensor(tensor->to(device, non_blocking, copy));
18231823
);
18241824
}
18251825

1826-
Tensor THSTensor_to_type(const Tensor tensor, int8_t scalar_type, const bool copy)
1826+
Tensor THSTensor_to_type(const Tensor tensor, int8_t scalar_type, const bool copy, const bool non_blocking)
18271827
{
1828-
CATCH_TENSOR(tensor->to(at::ScalarType(scalar_type), false, copy));
1828+
CATCH_TENSOR(tensor->to(at::ScalarType(scalar_type), non_blocking, copy));
18291829
}
18301830

1831-
Tensor THSTensor_to_type_and_device(const Tensor tensor, int8_t scalar_type, const int device_type, const int device_index, const bool copy)
1831+
Tensor THSTensor_to_type_and_device(const Tensor tensor, int8_t scalar_type, const int device_type, const int device_index, const bool copy, const bool non_blocking)
18321832
{
18331833
CATCH_RETURN_Tensor(
18341834
auto device = c10::Device((c10::DeviceType)device_type, (c10::DeviceIndex)device_index);
1835-
res = ResultTensor(tensor->to(device, at::ScalarType(scalar_type), false, copy));
1835+
res = ResultTensor(tensor->to(device, at::ScalarType(scalar_type), non_blocking, copy));
18361836
);
18371837
}
18381838

src/Native/LibTorchSharp/THSTensor.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1327,11 +1327,11 @@ EXPORT_API(Tensor) THSTensor_cumulative_trapezoid_dx(const Tensor y, const doubl
13271327

13281328
EXPORT_API(Tensor) THSTensor_to_dense(Tensor tensor);
13291329

1330-
EXPORT_API(Tensor) THSTensor_to_device(const Tensor tensor, const int device_type, const int device_index, const bool copy);
1330+
EXPORT_API(Tensor) THSTensor_to_device(const Tensor tensor, const int device_type, const int device_index, const bool copy, const bool non_blocking);
13311331

1332-
EXPORT_API(Tensor) THSTensor_to_type(const Tensor tensor, int8_t scalar_type, const bool copy);
1332+
EXPORT_API(Tensor) THSTensor_to_type(const Tensor tensor, int8_t scalar_type, const bool copy, const bool non_blocking);
13331333

1334-
EXPORT_API(Tensor) THSTensor_to_type_and_device(const Tensor tensor, int8_t scalar_type, const int device_type, const int device_index, const bool copy);
1334+
EXPORT_API(Tensor) THSTensor_to_type_and_device(const Tensor tensor, int8_t scalar_type, const int device_type, const int device_index, const bool copy, const bool non_blocking);
13351335

13361336
EXPORT_API(void) THSTensor_topk(const Tensor tensor, Tensor* (*allocator)(size_t length), const int k, const int64_t dim, const bool largest, const bool sorted);
13371337

src/TorchSharp/JIT/ScriptModule.cs

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ public override bool training {
143143
}
144144
}
145145

146-
protected internal override nn.Module _to(Device device, ScalarType dtype)
146+
protected internal override nn.Module _to(Device device, ScalarType dtype, bool non_blocking)
147147
{
148148
if (device.type != DeviceType.CUDA) { device = new Device(device.type, -1); };
149149

@@ -154,8 +154,8 @@ protected internal override nn.Module _to(Device device, ScalarType dtype)
154154
THSJIT_Module_to_device_dtype(handle, (sbyte)dtype, (int)device.type, device.index);
155155
CheckForErrors();
156156

157-
_toEpilog(device, dtype);
158-
_toScriptEpilog(device, dtype);
157+
_toEpilog(device, dtype, non_blocking);
158+
_toScriptEpilog(device, dtype, non_blocking);
159159
return this;
160160
}
161161

@@ -164,8 +164,12 @@ protected internal override nn.Module _to(Device device, ScalarType dtype)
164164
/// </summary>
165165
/// <param name="deviceType">The device type, e.g. 'CPU' or 'CUDA'.</param>
166166
/// <param name="deviceIndex">The optional device index.</param>
167+
/// <param name="non_blocking">
168+
/// When non_blocking is set, it tries to convert/move asynchronously with respect to the host if possible,
169+
/// e.g., moving CPU Tensors with pinned memory to CUDA devices.
170+
/// </param>
167171
/// <returns></returns>
168-
protected internal override nn.Module _to(DeviceType deviceType, int deviceIndex = -1)
172+
protected internal override nn.Module _to(DeviceType deviceType, int deviceIndex, bool non_blocking)
169173
{
170174
if (deviceType != DeviceType.CUDA) deviceIndex = -1;
171175

@@ -177,8 +181,8 @@ protected internal override nn.Module _to(DeviceType deviceType, int deviceIndex
177181
THSJIT_Module_to_device(handle, (int)deviceType, deviceIndex);
178182
CheckForErrors();
179183

180-
_toEpilog(deviceType, deviceIndex);
181-
_toScriptEpilog(deviceType, deviceIndex);
184+
_toEpilog(deviceType, deviceIndex, non_blocking);
185+
_toScriptEpilog(deviceType, deviceIndex, non_blocking);
182186
}
183187

184188
Debug.Assert(_deviceType == DeviceType.CUDA || _deviceIndex == -1);
@@ -193,38 +197,38 @@ protected internal override nn.Module _to(DeviceType deviceType, int deviceIndex
193197
/// Convert the parameters and buffers.
194198
/// </summary>
195199
/// <returns></returns>
196-
protected internal override nn.Module _to(ScalarType dtype)
200+
protected internal override nn.Module _to(ScalarType dtype, bool non_blocking)
197201
{
198202
THSJIT_Module_to_dtype(handle, (sbyte)dtype);
199203
CheckForErrors();
200204

201-
_toEpilog(dtype);
202-
_toScriptEpilog(dtype);
205+
_toEpilog(dtype, non_blocking);
206+
_toScriptEpilog(dtype, non_blocking);
203207

204208
return this;
205209
}
206210

207-
protected void _toScriptEpilog(ScalarType dtype)
211+
protected void _toScriptEpilog(ScalarType dtype, bool non_blocking)
208212
{
209-
_toScriptEpilog(dtype, null);
213+
_toScriptEpilog(dtype, null, non_blocking);
210214
}
211215

212-
protected void _toScriptEpilog(Device device, ScalarType dtype)
216+
protected void _toScriptEpilog(Device device, ScalarType dtype, bool non_blocking)
213217
{
214-
_toScriptEpilog(dtype, device);
218+
_toScriptEpilog(dtype, device, non_blocking);
215219
}
216220

217-
protected void _toScriptEpilog(DeviceType deviceType, int deviceIndex)
221+
protected void _toScriptEpilog(DeviceType deviceType, int deviceIndex, bool non_blocking)
218222
{
219-
_toScriptEpilog(null, new Device(deviceType, deviceIndex));
223+
_toScriptEpilog(null, new Device(deviceType, deviceIndex), non_blocking);
220224
}
221225

222-
private void _toScriptEpilog(ScalarType? dtype, Device device)
226+
private void _toScriptEpilog(ScalarType? dtype, Device device, bool non_blocking)
223227
{
224228
foreach (var (name, buffer) in named_attributes(recurse: false)) {
225229
if (name is null || !buffer.toWillCopy(dtype ?? buffer.dtype, device ?? buffer.device)) continue;
226230

227-
set_attribute(name, buffer.to(dtype ?? buffer.dtype, device ?? buffer.device, disposeAfter: true));
231+
set_attribute(name, buffer.to(dtype ?? buffer.dtype, device ?? buffer.device, disposeAfter: true, non_blocking: non_blocking));
228232
}
229233
}
230234

src/TorchSharp/NN/Activation/CELU.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,9 @@ public override string GetName()
3030

3131
// Rather than spending cycles only to discover that this module has neither
3232
// parameters nor buffers, just shortcut the move completely.
33-
protected internal override nn.Module _to(Device device, ScalarType dtype) => this;
34-
protected internal override nn.Module _to(DeviceType deviceType, int deviceIndex = -1) => this;
35-
protected internal override nn.Module _to(ScalarType dtype) => this;
33+
protected internal override nn.Module _to(Device device, ScalarType dtype, bool non_blocking) => this;
34+
protected internal override nn.Module _to(DeviceType deviceType, int deviceIndex, bool non_blocking) => this;
35+
protected internal override nn.Module _to(ScalarType dtype, bool non_blocking) => this;
3636
}
3737
}
3838

src/TorchSharp/NN/Activation/ELU.cs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,11 @@ public override string GetName()
2828
return typeof(ELU).Name;
2929
}
3030

31-
// Rather than spending cycles only to discover that this module has neither
31+
// Rather than spending cycles only to discover that this module has neither
3232
// parameters nor buffers, just shortcut the move completely.
33-
protected internal override nn.Module _to(Device device, ScalarType dtype) => this;
34-
protected internal override nn.Module _to(DeviceType deviceType, int deviceIndex = -1) => this;
35-
protected internal override nn.Module _to(ScalarType dtype) => this;
33+
protected internal override nn.Module _to(Device device, ScalarType dtype, bool non_blocking) => this;
34+
protected internal override nn.Module _to(DeviceType deviceType, int deviceIndex, bool non_blocking) => this;
35+
protected internal override nn.Module _to(ScalarType dtype, bool non_blocking) => this;
3636
}
3737
}
3838

src/TorchSharp/NN/Activation/GELU.cs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,11 @@ public override string GetName()
2828
return typeof(GELU).Name;
2929
}
3030

31-
// Rather than spending cycles only to discover that this module has neither
31+
// Rather than spending cycles only to discover that this module has neither
3232
// parameters nor buffers, just shortcut the move completely.
33-
protected internal override nn.Module _to(Device device, ScalarType dtype) => this;
34-
protected internal override nn.Module _to(DeviceType deviceType, int deviceIndex = -1) => this;
35-
protected internal override nn.Module _to(ScalarType dtype) => this;
33+
protected internal override nn.Module _to(Device device, ScalarType dtype, bool non_blocking) => this;
34+
protected internal override nn.Module _to(DeviceType deviceType, int deviceIndex, bool non_blocking) => this;
35+
protected internal override nn.Module _to(ScalarType dtype, bool non_blocking) => this;
3636
}
3737
}
3838

src/TorchSharp/NN/Activation/GLU.cs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,11 @@ public override string GetName()
2828
return typeof(GLU).Name;
2929
}
3030

31-
// Rather than spending cycles only to discover that this module has neither
31+
// Rather than spending cycles only to discover that this module has neither
3232
// parameters nor buffers, just shortcut the move completely.
33-
protected internal override nn.Module _to(Device device, ScalarType dtype) => this;
34-
protected internal override nn.Module _to(DeviceType deviceType, int deviceIndex = -1) => this;
35-
protected internal override nn.Module _to(ScalarType dtype) => this;
33+
protected internal override nn.Module _to(Device device, ScalarType dtype, bool non_blocking) => this;
34+
protected internal override nn.Module _to(DeviceType deviceType, int deviceIndex, bool non_blocking) => this;
35+
protected internal override nn.Module _to(ScalarType dtype, bool non_blocking) => this;
3636
}
3737
}
3838

0 commit comments

Comments (0)