
Commit 1781659

fix(torchx): better mps support (#1652)
1 parent d0d3665 commit 1781659

File tree: 10 files changed, +578 -230 lines

torchx/c_src/torchx.cpp

Lines changed: 107 additions & 57 deletions
@@ -10,12 +10,28 @@
 #include "torchx_nif_util.h"
 #include <iostream>
 #include <numeric>
+#include <stdexcept>

 namespace torchx {

 // Register TorchTensor as a resource type
 FINE_RESOURCE(TorchTensor);

+// Helper macro to provide better error messages for PyTorch exceptions
+#define TORCH_CATCH_ERROR(expr, operation_name) \
+  try { \
+    return expr; \
+  } catch (const c10::Error &e) { \
+    throw std::runtime_error(std::string(operation_name) + \
+                             " failed: " + e.what()); \
+  } catch (const std::exception &e) { \
+    throw std::runtime_error(std::string(operation_name) + \
+                             " failed: " + e.what()); \
+  } catch (...) { \
+    throw std::runtime_error(std::string(operation_name) + \
+                             " failed with unknown error"); \
+  }
+
 // Macro to register both _cpu and _io variants of a function
 // Following EXLA's pattern - create wrapper functions
 #define REGISTER_TENSOR_NIF(NAME) \
@@ -257,8 +273,11 @@ REGISTER_TENSOR_NIF(to_type);
 fine::Ok<fine::ResourcePtr<TorchTensor>>
 to_device(ErlNifEnv *env, fine::ResourcePtr<TorchTensor> tensor,
           std::tuple<int64_t, int64_t> device_tuple) {
-  auto device = tuple_to_device(device_tuple);
-  return tensor_ok(get_tensor(tensor).to(device));
+  TORCH_CATCH_ERROR(({
+                      auto device = tuple_to_device(device_tuple);
+                      tensor_ok(get_tensor(tensor).to(device));
+                    }),
+                    "Device transfer");
 }

 REGISTER_TENSOR_NIF(to_device);
@@ -340,15 +359,18 @@ fine::Ok<fine::ResourcePtr<TorchTensor>>
 index_put(ErlNifEnv *env, fine::ResourcePtr<TorchTensor> input,
           std::vector<fine::ResourcePtr<TorchTensor>> indices,
           fine::ResourcePtr<TorchTensor> values, bool accumulate) {
+  TORCH_CATCH_ERROR(
+      [&]() {
+        c10::List<std::optional<at::Tensor>> torch_indices;
+        for (const auto &idx : indices) {
+          torch_indices.push_back(get_tensor(idx));
+        }

-  c10::List<std::optional<at::Tensor>> torch_indices;
-  for (const auto &idx : indices) {
-    torch_indices.push_back(get_tensor(idx));
-  }
-
-  torch::Tensor result = get_tensor(input).clone();
-  result.index_put_(torch_indices, get_tensor(values), accumulate);
-  return tensor_ok(result);
+        torch::Tensor result = get_tensor(input).clone();
+        result.index_put_(torch_indices, get_tensor(values), accumulate);
+        return tensor_ok(result);
+      }(),
+      "index_put");
 }

 REGISTER_TENSOR_NIF(index_put);
@@ -713,26 +735,30 @@ REGISTER_TENSOR_NIF(matmul);
 fine::Ok<fine::ResourcePtr<TorchTensor>>
 pad(ErlNifEnv *env, fine::ResourcePtr<TorchTensor> tensor,
     fine::ResourcePtr<TorchTensor> constant, std::vector<int64_t> config) {
-  return tensor_ok(torch::constant_pad_nd(get_tensor(tensor),
-                                          vec_to_array_ref(config),
-                                          get_tensor(constant).item()));
+  TORCH_CATCH_ERROR(tensor_ok(torch::constant_pad_nd(
+                        get_tensor(tensor), vec_to_array_ref(config),
+                        get_tensor(constant).item())),
+                    "Pad operation");
 }

 REGISTER_TENSOR_NIF(pad);

 fine::Ok<fine::ResourcePtr<TorchTensor>>
 triangular_solve(ErlNifEnv *env, fine::ResourcePtr<TorchTensor> a,
                  fine::ResourcePtr<TorchTensor> b, bool transpose, bool upper) {
-  auto ts_a = get_tensor(a);
-  if (transpose) {
-    auto num_dims = ts_a.dim();
-    ts_a = torch::transpose(ts_a, num_dims - 2, num_dims - 1);
-    upper = !upper;
-  }
-
-  torch::Tensor result =
-      torch::linalg_solve_triangular(ts_a, get_tensor(b), upper, true, false);
-  return tensor_ok(result);
+  TORCH_CATCH_ERROR(({
+                      auto ts_a = get_tensor(a);
+                      if (transpose) {
+                        auto num_dims = ts_a.dim();
+                        ts_a =
+                            torch::transpose(ts_a, num_dims - 2, num_dims - 1);
+                        upper = !upper;
+                      }
+                      torch::Tensor result = torch::linalg_solve_triangular(
+                          ts_a, get_tensor(b), upper, true, false);
+                      tensor_ok(result);
+                    }),
+                    "Triangular solve");
 }

 REGISTER_TENSOR_NIF(triangular_solve);
@@ -952,20 +978,28 @@ REGISTER_TENSOR_NIF_ARITY(cholesky, cholesky_2);
 fine::Ok<
     std::tuple<fine::ResourcePtr<TorchTensor>, fine::ResourcePtr<TorchTensor>>>
 qr_1(ErlNifEnv *env, fine::ResourcePtr<TorchTensor> t) {
-  auto result = torch::linalg_qr(get_tensor(t), "reduced");
-  return fine::Ok(
-      std::make_tuple(fine::make_resource<TorchTensor>(std::get<0>(result)),
-                      fine::make_resource<TorchTensor>(std::get<1>(result))));
+  TORCH_CATCH_ERROR(
+      ({
+        auto result = torch::linalg_qr(get_tensor(t), "reduced");
+        fine::Ok(std::make_tuple(
+            fine::make_resource<TorchTensor>(std::get<0>(result)),
+            fine::make_resource<TorchTensor>(std::get<1>(result))));
+      }),
+      "QR decomposition");
 }

 fine::Ok<
     std::tuple<fine::ResourcePtr<TorchTensor>, fine::ResourcePtr<TorchTensor>>>
 qr_2(ErlNifEnv *env, fine::ResourcePtr<TorchTensor> t, bool reduced) {
-  auto result =
-      torch::linalg_qr(get_tensor(t), reduced ? "reduced" : "complete");
-  return fine::Ok(
-      std::make_tuple(fine::make_resource<TorchTensor>(std::get<0>(result)),
-                      fine::make_resource<TorchTensor>(std::get<1>(result))));
+  TORCH_CATCH_ERROR(
+      ({
+        auto result =
+            torch::linalg_qr(get_tensor(t), reduced ? "reduced" : "complete");
+        fine::Ok(std::make_tuple(
+            fine::make_resource<TorchTensor>(std::get<0>(result)),
+            fine::make_resource<TorchTensor>(std::get<1>(result))));
+      }),
+      "QR decomposition");
 }

 REGISTER_TENSOR_NIF_ARITY(qr, qr_1);
@@ -976,22 +1010,30 @@ fine::Ok<
     std::tuple<fine::ResourcePtr<TorchTensor>, fine::ResourcePtr<TorchTensor>,
                fine::ResourcePtr<TorchTensor>>>
 svd_1(ErlNifEnv *env, fine::ResourcePtr<TorchTensor> t) {
-  auto result = torch::linalg_svd(get_tensor(t), true);
-  return fine::Ok(
-      std::make_tuple(fine::make_resource<TorchTensor>(std::get<0>(result)),
-                      fine::make_resource<TorchTensor>(std::get<1>(result)),
-                      fine::make_resource<TorchTensor>(std::get<2>(result))));
+  TORCH_CATCH_ERROR(
+      ({
+        auto result = torch::linalg_svd(get_tensor(t), true);
+        fine::Ok(std::make_tuple(
+            fine::make_resource<TorchTensor>(std::get<0>(result)),
+            fine::make_resource<TorchTensor>(std::get<1>(result)),
+            fine::make_resource<TorchTensor>(std::get<2>(result))));
+      }),
+      "SVD decomposition");
 }

 fine::Ok<
     std::tuple<fine::ResourcePtr<TorchTensor>, fine::ResourcePtr<TorchTensor>,
                fine::ResourcePtr<TorchTensor>>>
 svd_2(ErlNifEnv *env, fine::ResourcePtr<TorchTensor> t, bool full_matrices) {
-  auto result = torch::linalg_svd(get_tensor(t), full_matrices);
-  return fine::Ok(
-      std::make_tuple(fine::make_resource<TorchTensor>(std::get<0>(result)),
-                      fine::make_resource<TorchTensor>(std::get<1>(result)),
-                      fine::make_resource<TorchTensor>(std::get<2>(result))));
+  TORCH_CATCH_ERROR(
+      ({
+        auto result = torch::linalg_svd(get_tensor(t), full_matrices);
+        fine::Ok(std::make_tuple(
+            fine::make_resource<TorchTensor>(std::get<0>(result)),
+            fine::make_resource<TorchTensor>(std::get<1>(result)),
+            fine::make_resource<TorchTensor>(std::get<2>(result))));
+      }),
+      "SVD decomposition");
 }

 REGISTER_TENSOR_NIF_ARITY(svd, svd_1);
@@ -1001,15 +1043,18 @@ fine::Ok<
     std::tuple<fine::ResourcePtr<TorchTensor>, fine::ResourcePtr<TorchTensor>,
                fine::ResourcePtr<TorchTensor>>>
 lu(ErlNifEnv *env, fine::ResourcePtr<TorchTensor> t) {
-  std::tuple<torch::Tensor, torch::Tensor> lu_result =
-      torch::linalg_lu_factor(get_tensor(t));
-  std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> plu =
-      torch::lu_unpack(std::get<0>(lu_result), std::get<1>(lu_result));
-
-  return fine::Ok(
-      std::make_tuple(fine::make_resource<TorchTensor>(std::get<0>(plu)),
-                      fine::make_resource<TorchTensor>(std::get<1>(plu)),
-                      fine::make_resource<TorchTensor>(std::get<2>(plu))));
+  TORCH_CATCH_ERROR(
+      ({
+        std::tuple<torch::Tensor, torch::Tensor> lu_result =
+            torch::linalg_lu_factor(get_tensor(t));
+        std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> plu =
+            torch::lu_unpack(std::get<0>(lu_result), std::get<1>(lu_result));
+        fine::Ok(std::make_tuple(
+            fine::make_resource<TorchTensor>(std::get<0>(plu)),
+            fine::make_resource<TorchTensor>(std::get<1>(plu)),
+            fine::make_resource<TorchTensor>(std::get<2>(plu))));
+      }),
+      "LU decomposition");
 }

 REGISTER_TENSOR_NIF(lu);
@@ -1035,19 +1080,24 @@ REGISTER_TENSOR_NIF(amin);
 fine::Ok<
     std::tuple<fine::ResourcePtr<TorchTensor>, fine::ResourcePtr<TorchTensor>>>
 eigh(ErlNifEnv *env, fine::ResourcePtr<TorchTensor> tensor) {
-  auto result = torch::linalg_eigh(get_tensor(tensor));
-  return fine::Ok(
-      std::make_tuple(fine::make_resource<TorchTensor>(std::get<0>(result)),
-                      fine::make_resource<TorchTensor>(std::get<1>(result))));
+  TORCH_CATCH_ERROR(
+      ({
+        auto result = torch::linalg_eigh(get_tensor(tensor));
+        fine::Ok(std::make_tuple(
+            fine::make_resource<TorchTensor>(std::get<0>(result)),
+            fine::make_resource<TorchTensor>(std::get<1>(result))));
+      }),
+      "Eigenvalue decomposition (eigh)");
 }

 REGISTER_TENSOR_NIF(eigh);

 fine::Ok<fine::ResourcePtr<TorchTensor>>
 solve(ErlNifEnv *env, fine::ResourcePtr<TorchTensor> tensorA,
       fine::ResourcePtr<TorchTensor> tensorB) {
-  return tensor_ok(
-      torch::linalg_solve(get_tensor(tensorA), get_tensor(tensorB)));
+  TORCH_CATCH_ERROR(
+      tensor_ok(torch::linalg_solve(get_tensor(tensorA), get_tensor(tensorB))),
+      "Linear solve");
 }

 REGISTER_TENSOR_NIF(solve);
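
Note on the pattern above (not part of the commit itself): each NIF body is wrapped in TORCH_CATCH_ERROR, which converts c10::Error and any other C++ exception into a std::runtime_error tagged with an operation name, so failures (for example operations unsupported on MPS) surface as readable messages naming the operation. Multi-statement bodies are passed either as a GCC/Clang statement expression ({ ... }), whose final expression becomes the value handed to the macro's return, or as an immediately invoked lambda, as in index_put. A minimal usage sketch with a hypothetical example_abs NIF that is not present in the file:

fine::Ok<fine::ResourcePtr<TorchTensor>>
example_abs(ErlNifEnv *env, fine::ResourcePtr<TorchTensor> tensor) {
  // If torch::abs throws (e.g. an unsupported dtype on the target device),
  // the macro rethrows it as "Abs operation failed: <reason>".
  TORCH_CATCH_ERROR(tensor_ok(torch::abs(get_tensor(tensor))), "Abs operation");
}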

torchx/lib/torchx.ex

Lines changed: 19 additions & 3 deletions
@@ -228,7 +228,19 @@ defmodule Torchx do
   def eye(size, type, device), do: eye(size, size, type, device)
   defdevice eye(m, n, type, device)
   defdevice from_blob(blob, shape, type, device)
-  defdevice to_device(tensor, device)
+
+  @torch_function {:to_device, 2}
+  def to_device(tensor, device) do
+    {[tensor_ref], _current_device} = prepare_tensors!([tensor])
+    {user_device, index} = normalize_device!(device)
+    target_device_struct = torch_device!(user_device, index)
+
+    case user_device do
+      :cpu -> Torchx.NIF.to_device_cpu(tensor_ref, target_device_struct)
+      _ -> Torchx.NIF.to_device_io(tensor_ref, target_device_struct)
+    end
+    |> unwrap_tensor!(user_device)
+  end

   ## Manipulation

@@ -466,7 +478,9 @@ defmodule Torchx do
         ref

       {other_dev, ref} when is_tensor(other_dev, ref) ->
-        raise ArgumentError, "cannot perform operation across devices #{dev} and #{other_dev}"
+        # Auto-transfer tensor to target device
+        {^dev, new_ref} = Torchx.to_device({other_dev, ref}, dev)
+        new_ref

       bad_tensor ->
         raise ArgumentError, "expected a Torchx tensor, got: #{inspect(bad_tensor)}"
@@ -484,7 +498,9 @@ defmodule Torchx do
         {ref, dev}

       {dev, ref}, other_dev when is_tensor(dev, ref) ->
-        raise ArgumentError, "cannot perform operation across devices #{dev} and #{other_dev}"
+        # Auto-transfer tensor to target device
+        {^other_dev, new_ref} = Torchx.to_device({dev, ref}, other_dev)
+        {new_ref, other_dev}

       [{dev, ref} | _] = tensors, nil when is_tensor(dev, ref) ->
         prepare_tensors_list!(tensors, dev)
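
Note (not part of the commit itself): combined, the torchx.ex changes make prepare_tensors! transfer a mismatched tensor to the expected device instead of raising, using the reimplemented Torchx.to_device/2, which calls the to_device_cpu NIF for a :cpu target and to_device_io otherwise. A rough behavioural sketch, assuming t_cpu and t_mps are Torchx tensors in the {device, ref} form used throughout this module, that the device argument may be given as an atom such as :mps, and that Torchx.add/2 is one of the generated binary tensor operations:

# Explicit transfer still works and returns a tensor on the target device.
{:mps, _moved_ref} = Torchx.to_device(t_cpu, :mps)

# Mixed-device call: previously this raised
# "cannot perform operation across devices cpu and mps";
# now the second operand is auto-transferred to the first operand's device.
{:mps, _sum_ref} = Torchx.add(t_mps, t_cpu)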
