File tree Expand file tree Collapse file tree 2 files changed +5
-3
lines changed
backends/apple/metal/runtime Expand file tree Collapse file tree 2 files changed +5
-3
lines changed Original file line number Diff line number Diff line change @@ -273,6 +273,8 @@ class ET_EXPERIMENTAL MetalBackend final
273273 n_outputs,
274274 args.size ());
275275
276+ int32_t mps_device_type = aoti_torch_device_type_mps (); // Returns 13
277+
276278 // NOTE: ExecutorTorch tensors are always on CPU/host memory
277279 // We need to create GPU copies for Metal kernel execution
278280 std::vector<AOTITensorHandle> gpu_inputs (
@@ -308,7 +310,7 @@ class ET_EXPERIMENTAL MetalBackend final
308310 sizes_vec.data (),
309311 nullptr , // use default strides
310312 static_cast <int32_t >(scalar_type),
311- 2 , // device_type = mps
313+ mps_device_type , // device_type = mps
312314 0 , // device_index = 0
313315 &gpu_input_handle);
314316
@@ -386,7 +388,7 @@ class ET_EXPERIMENTAL MetalBackend final
386388 sizes_vec.data (),
387389 nullptr , // use default strides
388390 static_cast <int32_t >(scalar_type),
389- 2 , // device_type = mps
391+ mps_device_type , // device_type = mps
390392 0 , // device_index = 0
391393 &gpu_output_handle);
392394
Original file line number Diff line number Diff line change @@ -143,7 +143,7 @@ AOTITorchError aoti_torch_empty_strided(
143143 dtype);
144144 int64_t nbytes = numel * element_size;
145145
146- int32_t mps_device_type = aoti_torch_device_type_mps (); // Returns 13
146+ int32_t mps_device_type = aoti_torch_device_type_mps (); // Returns 13
147147 if (device_type == mps_device_type) {
148148 ptr = metal_allocate_buffer (nbytes);
149149 if (!ptr) {
You can’t perform that action at this time.
0 commit comments