Skip to content

Commit 3e8648e

Browse files
Update
[ghstack-poisoned]
2 parents 404a1b8 + 3229b92 commit 3e8648e

File tree

2 files changed

+5
-3
lines changed

2 files changed

+5
-3
lines changed

backends/apple/metal/runtime/metal_backend.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -273,6 +273,8 @@ class ET_EXPERIMENTAL MetalBackend final
273273
n_outputs,
274274
args.size());
275275

276+
int32_t mps_device_type = aoti_torch_device_type_mps(); // Returns 13
277+
276278
// NOTE: ExecutorTorch tensors are always on CPU/host memory
277279
// We need to create GPU copies for Metal kernel execution
278280
std::vector<AOTITensorHandle> gpu_inputs(
@@ -308,7 +310,7 @@ class ET_EXPERIMENTAL MetalBackend final
308310
sizes_vec.data(),
309311
nullptr, // use default strides
310312
static_cast<int32_t>(scalar_type),
311-
2, // device_type = mps
313+
mps_device_type, // device_type = mps
312314
0, // device_index = 0
313315
&gpu_input_handle);
314316

@@ -386,7 +388,7 @@ class ET_EXPERIMENTAL MetalBackend final
386388
sizes_vec.data(),
387389
nullptr, // use default strides
388390
static_cast<int32_t>(scalar_type),
389-
2, // device_type = mps
391+
mps_device_type, // device_type = mps
390392
0, // device_index = 0
391393
&gpu_output_handle);
392394

backends/apple/metal/runtime/shims/memory.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ AOTITorchError aoti_torch_empty_strided(
143143
dtype);
144144
int64_t nbytes = numel * element_size;
145145

146-
int32_t mps_device_type = aoti_torch_device_type_mps(); // Returns 13
146+
int32_t mps_device_type = aoti_torch_device_type_mps(); // Returns 13
147147
if (device_type == mps_device_type) {
148148
ptr = metal_allocate_buffer(nbytes);
149149
if (!ptr) {

0 commit comments

Comments
 (0)