1010#include " torchx_nif_util.h"
1111#include < iostream>
1212#include < numeric>
13+ #include < stdexcept>
1314
1415namespace torchx {
1516
1617// Register TorchTensor as a resource type
1718FINE_RESOURCE (TorchTensor);
1819
20+ // Helper macro to provide better error messages for PyTorch exceptions
21+ #define TORCH_CATCH_ERROR (expr, operation_name ) \
22+ try { \
23+ return expr; \
24+ } catch (const c10::Error &e) { \
25+ throw std::runtime_error (std::string (operation_name) + \
26+ " failed: " + e.what ()); \
27+ } catch (const std::exception &e) { \
28+ throw std::runtime_error (std::string (operation_name) + \
29+ " failed: " + e.what ()); \
30+ } catch (...) { \
31+ throw std::runtime_error (std::string (operation_name) + \
32+ " failed with unknown error" ); \
33+ }
34+
1935// Macro to register both _cpu and _io variants of a function
2036// Following EXLA's pattern - create wrapper functions
2137#define REGISTER_TENSOR_NIF (NAME ) \
@@ -257,8 +273,11 @@ REGISTER_TENSOR_NIF(to_type);
257273fine::Ok<fine::ResourcePtr<TorchTensor>>
258274to_device (ErlNifEnv *env, fine::ResourcePtr<TorchTensor> tensor,
259275 std::tuple<int64_t , int64_t > device_tuple) {
260- auto device = tuple_to_device (device_tuple);
261- return tensor_ok (get_tensor (tensor).to (device));
276+ TORCH_CATCH_ERROR (({
277+ auto device = tuple_to_device (device_tuple);
278+ tensor_ok (get_tensor (tensor).to (device));
279+ }),
280+ " Device transfer" );
262281}
263282
264283REGISTER_TENSOR_NIF (to_device);
@@ -340,15 +359,18 @@ fine::Ok<fine::ResourcePtr<TorchTensor>>
340359index_put (ErlNifEnv *env, fine::ResourcePtr<TorchTensor> input,
341360 std::vector<fine::ResourcePtr<TorchTensor>> indices,
342361 fine::ResourcePtr<TorchTensor> values, bool accumulate) {
362+ TORCH_CATCH_ERROR (
363+ [&]() {
364+ c10::List<std::optional<at::Tensor>> torch_indices;
365+ for (const auto &idx : indices) {
366+ torch_indices.push_back (get_tensor (idx));
367+ }
343368
344- c10::List<std::optional<at::Tensor>> torch_indices;
345- for (const auto &idx : indices) {
346- torch_indices.push_back (get_tensor (idx));
347- }
348-
349- torch::Tensor result = get_tensor (input).clone ();
350- result.index_put_ (torch_indices, get_tensor (values), accumulate);
351- return tensor_ok (result);
369+ torch::Tensor result = get_tensor (input).clone ();
370+ result.index_put_ (torch_indices, get_tensor (values), accumulate);
371+ return tensor_ok (result);
372+ }(),
373+ " index_put" );
352374}
353375
354376REGISTER_TENSOR_NIF (index_put);
@@ -713,26 +735,30 @@ REGISTER_TENSOR_NIF(matmul);
713735fine::Ok<fine::ResourcePtr<TorchTensor>>
714736pad (ErlNifEnv *env, fine::ResourcePtr<TorchTensor> tensor,
715737 fine::ResourcePtr<TorchTensor> constant, std::vector<int64_t > config) {
716- return tensor_ok (torch::constant_pad_nd (get_tensor (tensor),
717- vec_to_array_ref (config),
718- get_tensor (constant).item ()));
738+ TORCH_CATCH_ERROR (tensor_ok (torch::constant_pad_nd (
739+ get_tensor (tensor), vec_to_array_ref (config),
740+ get_tensor (constant).item ())),
741+ " Pad operation" );
719742}
720743
721744REGISTER_TENSOR_NIF (pad);
722745
723746fine::Ok<fine::ResourcePtr<TorchTensor>>
724747triangular_solve (ErlNifEnv *env, fine::ResourcePtr<TorchTensor> a,
725748 fine::ResourcePtr<TorchTensor> b, bool transpose, bool upper) {
726- auto ts_a = get_tensor (a);
727- if (transpose) {
728- auto num_dims = ts_a.dim ();
729- ts_a = torch::transpose (ts_a, num_dims - 2 , num_dims - 1 );
730- upper = !upper;
731- }
732-
733- torch::Tensor result =
734- torch::linalg_solve_triangular (ts_a, get_tensor (b), upper, true , false );
735- return tensor_ok (result);
749+ TORCH_CATCH_ERROR (({
750+ auto ts_a = get_tensor (a);
751+ if (transpose) {
752+ auto num_dims = ts_a.dim ();
753+ ts_a =
754+ torch::transpose (ts_a, num_dims - 2 , num_dims - 1 );
755+ upper = !upper;
756+ }
757+ torch::Tensor result = torch::linalg_solve_triangular (
758+ ts_a, get_tensor (b), upper, true , false );
759+ tensor_ok (result);
760+ }),
761+ " Triangular solve" );
736762}
737763
738764REGISTER_TENSOR_NIF (triangular_solve);
@@ -952,20 +978,28 @@ REGISTER_TENSOR_NIF_ARITY(cholesky, cholesky_2);
952978fine::Ok<
953979 std::tuple<fine::ResourcePtr<TorchTensor>, fine::ResourcePtr<TorchTensor>>>
954980qr_1 (ErlNifEnv *env, fine::ResourcePtr<TorchTensor> t) {
955- auto result = torch::linalg_qr (get_tensor (t), " reduced" );
956- return fine::Ok (
957- std::make_tuple (fine::make_resource<TorchTensor>(std::get<0 >(result)),
958- fine::make_resource<TorchTensor>(std::get<1 >(result))));
981+ TORCH_CATCH_ERROR (
982+ ({
983+ auto result = torch::linalg_qr (get_tensor (t), " reduced" );
984+ fine::Ok (std::make_tuple (
985+ fine::make_resource<TorchTensor>(std::get<0 >(result)),
986+ fine::make_resource<TorchTensor>(std::get<1 >(result))));
987+ }),
988+ " QR decomposition" );
959989}
960990
961991fine::Ok<
962992 std::tuple<fine::ResourcePtr<TorchTensor>, fine::ResourcePtr<TorchTensor>>>
963993qr_2 (ErlNifEnv *env, fine::ResourcePtr<TorchTensor> t, bool reduced) {
964- auto result =
965- torch::linalg_qr (get_tensor (t), reduced ? " reduced" : " complete" );
966- return fine::Ok (
967- std::make_tuple (fine::make_resource<TorchTensor>(std::get<0 >(result)),
968- fine::make_resource<TorchTensor>(std::get<1 >(result))));
994+ TORCH_CATCH_ERROR (
995+ ({
996+ auto result =
997+ torch::linalg_qr (get_tensor (t), reduced ? " reduced" : " complete" );
998+ fine::Ok (std::make_tuple (
999+ fine::make_resource<TorchTensor>(std::get<0 >(result)),
1000+ fine::make_resource<TorchTensor>(std::get<1 >(result))));
1001+ }),
1002+ " QR decomposition" );
9691003}
9701004
9711005REGISTER_TENSOR_NIF_ARITY (qr, qr_1);
@@ -976,22 +1010,30 @@ fine::Ok<
9761010 std::tuple<fine::ResourcePtr<TorchTensor>, fine::ResourcePtr<TorchTensor>,
9771011 fine::ResourcePtr<TorchTensor>>>
9781012svd_1 (ErlNifEnv *env, fine::ResourcePtr<TorchTensor> t) {
979- auto result = torch::linalg_svd (get_tensor (t), true );
980- return fine::Ok (
981- std::make_tuple (fine::make_resource<TorchTensor>(std::get<0 >(result)),
982- fine::make_resource<TorchTensor>(std::get<1 >(result)),
983- fine::make_resource<TorchTensor>(std::get<2 >(result))));
1013+ TORCH_CATCH_ERROR (
1014+ ({
1015+ auto result = torch::linalg_svd (get_tensor (t), true );
1016+ fine::Ok (std::make_tuple (
1017+ fine::make_resource<TorchTensor>(std::get<0 >(result)),
1018+ fine::make_resource<TorchTensor>(std::get<1 >(result)),
1019+ fine::make_resource<TorchTensor>(std::get<2 >(result))));
1020+ }),
1021+ " SVD decomposition" );
9841022}
9851023
9861024fine::Ok<
9871025 std::tuple<fine::ResourcePtr<TorchTensor>, fine::ResourcePtr<TorchTensor>,
9881026 fine::ResourcePtr<TorchTensor>>>
9891027svd_2 (ErlNifEnv *env, fine::ResourcePtr<TorchTensor> t, bool full_matrices) {
990- auto result = torch::linalg_svd (get_tensor (t), full_matrices);
991- return fine::Ok (
992- std::make_tuple (fine::make_resource<TorchTensor>(std::get<0 >(result)),
993- fine::make_resource<TorchTensor>(std::get<1 >(result)),
994- fine::make_resource<TorchTensor>(std::get<2 >(result))));
1028+ TORCH_CATCH_ERROR (
1029+ ({
1030+ auto result = torch::linalg_svd (get_tensor (t), full_matrices);
1031+ fine::Ok (std::make_tuple (
1032+ fine::make_resource<TorchTensor>(std::get<0 >(result)),
1033+ fine::make_resource<TorchTensor>(std::get<1 >(result)),
1034+ fine::make_resource<TorchTensor>(std::get<2 >(result))));
1035+ }),
1036+ " SVD decomposition" );
9951037}
9961038
9971039REGISTER_TENSOR_NIF_ARITY (svd, svd_1);
@@ -1001,15 +1043,18 @@ fine::Ok<
10011043 std::tuple<fine::ResourcePtr<TorchTensor>, fine::ResourcePtr<TorchTensor>,
10021044 fine::ResourcePtr<TorchTensor>>>
10031045lu (ErlNifEnv *env, fine::ResourcePtr<TorchTensor> t) {
1004- std::tuple<torch::Tensor, torch::Tensor> lu_result =
1005- torch::linalg_lu_factor (get_tensor (t));
1006- std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> plu =
1007- torch::lu_unpack (std::get<0 >(lu_result), std::get<1 >(lu_result));
1008-
1009- return fine::Ok (
1010- std::make_tuple (fine::make_resource<TorchTensor>(std::get<0 >(plu)),
1011- fine::make_resource<TorchTensor>(std::get<1 >(plu)),
1012- fine::make_resource<TorchTensor>(std::get<2 >(plu))));
1046+ TORCH_CATCH_ERROR (
1047+ ({
1048+ std::tuple<torch::Tensor, torch::Tensor> lu_result =
1049+ torch::linalg_lu_factor (get_tensor (t));
1050+ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> plu =
1051+ torch::lu_unpack (std::get<0 >(lu_result), std::get<1 >(lu_result));
1052+ fine::Ok (std::make_tuple (
1053+ fine::make_resource<TorchTensor>(std::get<0 >(plu)),
1054+ fine::make_resource<TorchTensor>(std::get<1 >(plu)),
1055+ fine::make_resource<TorchTensor>(std::get<2 >(plu))));
1056+ }),
1057+ " LU decomposition" );
10131058}
10141059
10151060REGISTER_TENSOR_NIF (lu);
@@ -1035,19 +1080,24 @@ REGISTER_TENSOR_NIF(amin);
10351080fine::Ok<
10361081 std::tuple<fine::ResourcePtr<TorchTensor>, fine::ResourcePtr<TorchTensor>>>
10371082eigh (ErlNifEnv *env, fine::ResourcePtr<TorchTensor> tensor) {
1038- auto result = torch::linalg_eigh (get_tensor (tensor));
1039- return fine::Ok (
1040- std::make_tuple (fine::make_resource<TorchTensor>(std::get<0 >(result)),
1041- fine::make_resource<TorchTensor>(std::get<1 >(result))));
1083+ TORCH_CATCH_ERROR (
1084+ ({
1085+ auto result = torch::linalg_eigh (get_tensor (tensor));
1086+ fine::Ok (std::make_tuple (
1087+ fine::make_resource<TorchTensor>(std::get<0 >(result)),
1088+ fine::make_resource<TorchTensor>(std::get<1 >(result))));
1089+ }),
1090+ " Eigenvalue decomposition (eigh)" );
10421091}
10431092
10441093REGISTER_TENSOR_NIF (eigh);
10451094
10461095fine::Ok<fine::ResourcePtr<TorchTensor>>
10471096solve (ErlNifEnv *env, fine::ResourcePtr<TorchTensor> tensorA,
10481097 fine::ResourcePtr<TorchTensor> tensorB) {
1049- return tensor_ok (
1050- torch::linalg_solve (get_tensor (tensorA), get_tensor (tensorB)));
1098+ TORCH_CATCH_ERROR (
1099+ tensor_ok (torch::linalg_solve (get_tensor (tensorA), get_tensor (tensorB))),
1100+ " Linear solve" );
10511101}
10521102
10531103REGISTER_TENSOR_NIF (solve);
0 commit comments