Skip to content

Commit a566ae1

Browse files
committed
eckit::linalg::dense::LinearAlgebraTorch, eckit::linalg::sparse::LinearAlgebraTorch: (1) single place for backend device name logic (detail::Torch), (2) const device/scalar type at construction, (3) limit MPS device (Apple) to dense functionality and single precision (current status)
1 parent ccb0c52 commit a566ae1

File tree

6 files changed

+94
-125
lines changed

6 files changed

+94
-125
lines changed

src/eckit/linalg/dense/LinearAlgebraTorch.cc

Lines changed: 17 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -12,40 +12,33 @@
1212
#include "eckit/linalg/dense/LinearAlgebraTorch.h"
1313

1414
#include <cstring>
15-
#include <ostream>
1615

1716
#include "eckit/exception/Exceptions.h"
1817
#include "eckit/linalg/Matrix.h"
1918
#include "eckit/linalg/Vector.h"
20-
2119
#include "eckit/linalg/detail/Torch.h"
2220

2321

2422
namespace eckit::linalg::dense {
2523

2624

27-
static const LinearAlgebraTorch LA_TORCH_CPU_1("torch");
28-
static const LinearAlgebraTorch LA_TORCH_CPU_2("torch-cpu");
29-
static const LinearAlgebraTorch LA_TORCH_CUDA("torch-cuda");
30-
static const LinearAlgebraTorch LA_TORCH_HIP("torch-hip");
31-
static const LinearAlgebraTorch LA_TORCH_MPS("torch-mps");
32-
static const LinearAlgebraTorch LA_TORCH_XPU("torch-xpu");
33-
static const LinearAlgebraTorch LA_TORCH_XLA("torch-xla");
34-
static const LinearAlgebraTorch LA_TORCH_META("torch-meta");
35-
36-
37-
using detail::get_torch_device;
38-
using detail::make_torch_dense_tensor;
39-
using detail::torch_tensor_transpose;
25+
static const LinearAlgebraTorch LA_TORCH_CPU_1("torch", torch::DeviceType::CPU);
26+
static const LinearAlgebraTorch LA_TORCH_CPU_2("torch-cpu", torch::DeviceType::CPU);
27+
static const LinearAlgebraTorch LA_TORCH_CUDA("torch-cuda", torch::DeviceType::CUDA);
28+
static const LinearAlgebraTorch LA_TORCH_HIP("torch-hip", torch::DeviceType::HIP);
29+
static const LinearAlgebraTorch LA_TORCH_MPS("torch-mps", torch::DeviceType::MPS, torch::kFloat32);
30+
static const LinearAlgebraTorch LA_TORCH_XPU("torch-xpu", torch::DeviceType::XPU);
31+
static const LinearAlgebraTorch LA_TORCH_XLA("torch-xla", torch::DeviceType::XLA);
32+
static const LinearAlgebraTorch LA_TORCH_META("torch-meta", torch::DeviceType::Meta);
4033

4134

4235
Scalar LinearAlgebraTorch::dot(const Vector& x, const Vector& y) const {
4336
ASSERT(x.size() == y.size());
4437

45-
auto x_tensor = make_torch_dense_tensor(x, get_torch_device(name()));
46-
auto y_tensor = make_torch_dense_tensor(y, get_torch_device(name()));
38+
auto x_tensor = make_dense_tensor(x);
39+
auto y_tensor = make_dense_tensor(y);
4740

48-
return torch::dot(x_tensor, y_tensor).to(torch::kCPU).item<Scalar>();
41+
return torch::dot(x_tensor, y_tensor).to(torch::DeviceType::CPU, torch::kFloat64).item<Scalar>();
4942
}
5043

5144

@@ -54,9 +47,9 @@ void LinearAlgebraTorch::gemv(const Matrix& A, const Vector& x, Vector& y) const
5447
ASSERT(A.rows() == y.rows());
5548

5649
// multiplication
57-
auto A_tensor = make_torch_dense_tensor(A, get_torch_device(name()));
58-
auto x_tensor = make_torch_dense_tensor(x, get_torch_device(name()));
59-
auto y_tensor = torch::matmul(A_tensor, x_tensor).to(torch::kCPU).contiguous();
50+
auto A_tensor = make_dense_tensor(A);
51+
auto x_tensor = make_dense_tensor(x);
52+
auto y_tensor = tensor_to_host(torch::matmul(A_tensor, x_tensor));
6053

6154
// assignment
6255
std::memcpy(y.data(), y_tensor.data_ptr<Scalar>(), y.rows() * sizeof(Scalar));
@@ -69,18 +62,13 @@ void LinearAlgebraTorch::gemm(const Matrix& A, const Matrix& X, Matrix& Y) const
6962
ASSERT(X.cols() == Y.cols());
7063

7164
// multiplication and conversion from column-major to row-major (and back)
72-
auto A_tensor = make_torch_dense_tensor(A, get_torch_device(name()));
73-
auto X_tensor = make_torch_dense_tensor(X, get_torch_device(name()));
74-
auto Y_tensor = torch_tensor_transpose(torch::matmul(A_tensor, X_tensor)).to(torch::kCPU).contiguous();
65+
auto A_tensor = make_dense_tensor(A);
66+
auto X_tensor = make_dense_tensor(X);
67+
auto Y_tensor = tensor_transpose(tensor_to_host(torch::matmul(A_tensor, X_tensor)));
7568

7669
// assignment
7770
std::memcpy(Y.data(), Y_tensor.data_ptr<Scalar>(), Y.size() * sizeof(Scalar));
7871
}
7972

8073

81-
void LinearAlgebraTorch::print(std::ostream& out) const {
82-
out << "LinearAlgebraTorch[]";
83-
}
84-
85-
8674
} // namespace eckit::linalg::dense

src/eckit/linalg/dense/LinearAlgebraTorch.h

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,18 +12,21 @@
1212
#pragma once
1313

1414
#include "eckit/linalg/LinearAlgebraDense.h"
15+
#include "eckit/linalg/detail/Torch.h"
1516

1617

1718
namespace eckit::linalg::dense {
1819

19-
struct LinearAlgebraTorch final : public LinearAlgebraDense {
20-
LinearAlgebraTorch() = default;
21-
LinearAlgebraTorch(const std::string& name) : LinearAlgebraDense(name) {}
20+
21+
struct LinearAlgebraTorch final : public LinearAlgebraDense, detail::Torch {
22+
LinearAlgebraTorch(const std::string& name, torch::DeviceType device, torch::ScalarType scalar = torch::kFloat64) :
23+
LinearAlgebraDense(name), Torch(device, scalar) {}
2224

2325
Scalar dot(const Vector& x, const Vector& y) const override;
2426
void gemv(const Matrix& A, const Vector& x, Vector& y) const override;
2527
void gemm(const Matrix& A, const Matrix& X, Matrix& Y) const override;
26-
void print(std::ostream&) const override;
28+
void print(std::ostream& os) const override { Torch::print(os); }
2729
};
2830

31+
2932
} // namespace eckit::linalg::dense

src/eckit/linalg/detail/Torch.cc

Lines changed: 21 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,9 @@
1111

1212
#include "eckit/linalg/detail/Torch.h"
1313

14-
#include <map>
14+
#include <ostream>
1515
#include <type_traits>
1616

17-
#include "eckit/exception/Exceptions.h"
1817
#include "eckit/linalg/Matrix.h"
1918
#include "eckit/linalg/SparseMatrix.h"
2019
#include "eckit/linalg/Vector.h"
@@ -27,66 +26,48 @@ static_assert(std::is_same<int32_t, Index>::value, "Index type mismatch");
2726
static_assert(std::is_same<double, Scalar>::value, "Scalar type mismatch");
2827

2928

30-
torch::DeviceType get_torch_device(const std::string& name) {
31-
static const auto device = [&name]() {
32-
const std::map<std::string, torch::DeviceType> types{
33-
{"cpu", torch::DeviceType::CPU}, //
34-
{"cuda", torch::DeviceType::CUDA}, //
35-
{"hip", torch::DeviceType::HIP}, //
36-
{"mps", torch::DeviceType::MPS}, //
37-
{"xpu", torch::DeviceType::XPU}, //
38-
{"xla", torch::DeviceType::XLA}, //
39-
{"meta", torch::DeviceType::Meta}, //
40-
};
41-
42-
const auto sep = name.find_first_of('-');
43-
if (sep == std::string::npos) {
44-
return torch::DeviceType::CPU;
45-
}
46-
47-
if (auto it = types.find(name.substr(sep + 1)); it != types.end()) {
48-
return it->second;
49-
}
50-
51-
throw eckit::UserError("Unknown torch device: " + name);
52-
}();
53-
54-
return device;
29+
torch::Tensor Torch::tensor_transpose(const torch::Tensor& tensor) const {
30+
return tensor.transpose(0, 1).contiguous();
5531
}
5632

5733

58-
torch::Tensor torch_tensor_transpose(const torch::Tensor& tensor) {
59-
return tensor.transpose(0, 1).contiguous();
34+
torch::Tensor Torch::tensor_to_host(const torch::Tensor& tensor) const {
35+
return tensor.to(torch::DeviceType::CPU, torch::kFloat64).contiguous(); // reverse MPS float32 (if applicable)
6036
}
6137

6238

63-
torch::Tensor make_torch_dense_tensor(const Matrix& A, torch::DeviceType device) {
39+
torch::Tensor Torch::make_dense_tensor(const Matrix& A) const {
6440
auto Ni = static_cast<int64_t>(A.cols());
6541
auto Nj = static_cast<int64_t>(A.rows());
6642

67-
return torch_tensor_transpose(
68-
torch::from_blob(const_cast<Scalar*>(A.data()), {Ni, Nj}, torch::kFloat64).to(device));
43+
return tensor_transpose(
44+
torch::from_blob(const_cast<Scalar*>(A.data()), {Ni, Nj}, torch::kFloat64).to(device_, scalar_));
6945
}
7046

7147

72-
torch::Tensor make_torch_dense_tensor(const Vector& V, torch::DeviceType device) {
48+
torch::Tensor Torch::make_dense_tensor(const Vector& V) const {
7349
auto Ni = static_cast<int64_t>(V.size());
7450

75-
return torch::from_blob(const_cast<Scalar*>(V.data()), {Ni}, torch::kFloat64).to(device);
51+
return torch::from_blob(const_cast<Scalar*>(V.data()), {Ni}, torch::kFloat64).to(device_, scalar_);
7652
}
7753

7854

79-
torch::Tensor make_torch_sparse_csr(const SparseMatrix& A, torch::DeviceType device) {
55+
torch::Tensor Torch::make_sparse_csr_tensor(const SparseMatrix& A) const {
8056
auto Ni = static_cast<int64_t>(A.rows());
8157
auto Nj = static_cast<int64_t>(A.cols());
8258
auto Nz = static_cast<int64_t>(A.nonZeros());
8359

84-
auto ia = torch::from_blob(const_cast<Index*>(A.outer()), {Ni + 1}, torch::kInt32).to(device, torch::kInt64);
85-
auto ja = torch::from_blob(const_cast<Index*>(A.inner()), {Nz}, torch::kInt32).to(device, torch::kInt64);
86-
auto a = torch::from_blob(const_cast<Scalar*>(A.data()), {Nz}, torch::kFloat64).to(device);
60+
auto ia = torch::from_blob(const_cast<Index*>(A.outer()), {Ni + 1}, torch::kInt32).to(device_, torch::kInt64);
61+
auto ja = torch::from_blob(const_cast<Index*>(A.inner()), {Nz}, torch::kInt32).to(device_, torch::kInt64);
62+
auto a = torch::from_blob(const_cast<Scalar*>(A.data()), {Nz}, torch::kFloat64).to(device_, scalar_);
63+
64+
return torch::sparse_csr_tensor(ia, ja, a, {Ni, Nj},
65+
torch::TensorOptions().dtype(scalar_).device(device_).layout(torch::kSparseCsr));
66+
}
67+
8768

88-
return torch::sparse_csr_tensor(
89-
ia, ja, a, {Ni, Nj}, torch::TensorOptions().dtype(torch::kFloat64).device(device).layout(torch::kSparseCsr));
69+
void Torch::print(std::ostream& os) const {
70+
os << "LinearAlgebraTorch[device=" << device_ << "]";
9071
}
9172

9273

src/eckit/linalg/detail/Torch.h

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
#pragma once
1313

14-
#include <string>
14+
#include <iosfwd>
1515

1616
#include "eckit/linalg/types.h"
1717

@@ -21,11 +21,31 @@
2121
namespace eckit::linalg::detail {
2222

2323

24-
torch::DeviceType get_torch_device(const std::string&);
25-
torch::Tensor torch_tensor_transpose(const torch::Tensor&);
26-
torch::Tensor make_torch_dense_tensor(const Matrix&, torch::DeviceType);
27-
torch::Tensor make_torch_dense_tensor(const Vector&, torch::DeviceType);
28-
torch::Tensor make_torch_sparse_csr(const SparseMatrix&, torch::DeviceType);
24+
/**
25+
* @brief Torch tensor creation and device management for linear algebra backends.
26+
*
27+
* Copies data host to/from device per operation. Transfer overhead may outweigh accelerator device gains for
28+
* small/frequent operations; best suited for large matrices where compute dominates.
29+
*/
30+
class Torch {
31+
protected:
32+
33+
explicit Torch(torch::DeviceType device, torch::ScalarType scalar) : device_(device), scalar_(scalar) {}
34+
35+
torch::Tensor tensor_transpose(const torch::Tensor&) const;
36+
torch::Tensor tensor_to_host(const torch::Tensor&) const;
37+
38+
torch::Tensor make_dense_tensor(const Matrix&) const;
39+
torch::Tensor make_dense_tensor(const Vector&) const;
40+
torch::Tensor make_sparse_csr_tensor(const SparseMatrix&) const;
41+
42+
void print(std::ostream&) const;
43+
44+
private:
45+
46+
const torch::DeviceType device_;
47+
const torch::ScalarType scalar_;
48+
};
2949

3050

3151
} // namespace eckit::linalg::detail

src/eckit/linalg/sparse/LinearAlgebraTorch.cc

Lines changed: 15 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -12,34 +12,26 @@
1212
#include "eckit/linalg/sparse/LinearAlgebraTorch.h"
1313

1414
#include <cstring>
15-
#include <ostream>
1615

1716
#include "eckit/exception/Exceptions.h"
1817
#include "eckit/linalg/Matrix.h"
1918
#include "eckit/linalg/SparseMatrix.h"
2019
#include "eckit/linalg/Vector.h"
21-
#include "eckit/linalg/sparse/LinearAlgebraGeneric.h"
22-
2320
#include "eckit/linalg/detail/Torch.h"
21+
#include "eckit/linalg/sparse/LinearAlgebraGeneric.h"
2422

2523

2624
namespace eckit::linalg::sparse {
2725

2826

29-
static const LinearAlgebraTorch LA_TORCH_CPU_1("torch");
30-
static const LinearAlgebraTorch LA_TORCH_CPU_2("torch-cpu");
31-
static const LinearAlgebraTorch LA_TORCH_CUDA("torch-cuda");
32-
static const LinearAlgebraTorch LA_TORCH_HIP("torch-hip");
33-
static const LinearAlgebraTorch LA_TORCH_MPS("torch-mps");
34-
static const LinearAlgebraTorch LA_TORCH_XPU("torch-xpu");
35-
static const LinearAlgebraTorch LA_TORCH_XLA("torch-xla");
36-
static const LinearAlgebraTorch LA_TORCH_META("torch-meta");
37-
38-
39-
using detail::get_torch_device;
40-
using detail::make_torch_dense_tensor;
41-
using detail::make_torch_sparse_csr;
42-
using detail::torch_tensor_transpose;
27+
static const LinearAlgebraTorch LA_TORCH_CPU_1("torch", torch::DeviceType::CPU);
28+
static const LinearAlgebraTorch LA_TORCH_CPU_2("torch-cpu", torch::DeviceType::CPU);
29+
static const LinearAlgebraTorch LA_TORCH_CUDA("torch-cuda", torch::DeviceType::CUDA);
30+
static const LinearAlgebraTorch LA_TORCH_HIP("torch-hip", torch::DeviceType::HIP);
31+
// static const LinearAlgebraTorch LA_TORCH_MPS("torch-mps", torch::DeviceType::MPS);
32+
static const LinearAlgebraTorch LA_TORCH_XPU("torch-xpu", torch::DeviceType::XPU);
33+
static const LinearAlgebraTorch LA_TORCH_XLA("torch-xla", torch::DeviceType::XLA);
34+
static const LinearAlgebraTorch LA_TORCH_META("torch-meta", torch::DeviceType::Meta);
4335

4436

4537
void LinearAlgebraTorch::spmv(const SparseMatrix& A, const Vector& x, Vector& y) const {
@@ -48,17 +40,10 @@ void LinearAlgebraTorch::spmv(const SparseMatrix& A, const Vector& x, Vector& y)
4840
ASSERT(Ni == y.rows());
4941
ASSERT(Nj == x.rows());
5042

51-
// Note: This implementation copies data to GPU memory for each operation and immediately
52-
// copies the result back to CPU. This data transfer overhead can be significant and may
53-
// negate the performance benefits of GPU computation for small matrices or frequent operations.
54-
// GPU acceleration is most beneficial for large matrices where computation time dominates
55-
// transfer overhead. For optimal performance, consider keeping data on GPU across multiple
56-
// operations rather than transferring for each call.
57-
5843
// multiplication
59-
auto A_tensor = make_torch_sparse_csr(A, get_torch_device(name()));
60-
auto x_tensor = make_torch_dense_tensor(x, get_torch_device(name()));
61-
auto y_tensor = torch::matmul(A_tensor, x_tensor).to(torch::kCPU).contiguous();
44+
auto A_tensor = make_sparse_csr_tensor(A);
45+
auto x_tensor = make_dense_tensor(x);
46+
auto y_tensor = tensor_to_host(torch::matmul(A_tensor, x_tensor));
6247

6348
// assignment
6449
std::memcpy(y.data(), y_tensor.data_ptr<Scalar>(), Ni * sizeof(Scalar));
@@ -73,17 +58,10 @@ void LinearAlgebraTorch::spmm(const SparseMatrix& A, const Matrix& X, Matrix& Y)
7358
ASSERT(Nj == X.rows());
7459
ASSERT(Nk == Y.cols());
7560

76-
// Note: This implementation copies data to GPU memory for each operation and immediately
77-
// copies the result back to CPU. This data transfer overhead can be significant and may
78-
// negate the performance benefits of GPU computation for small matrices or frequent operations.
79-
// GPU acceleration is most beneficial for large matrices where computation time dominates
80-
// transfer overhead. For optimal performance, consider keeping data on GPU across multiple
81-
// operations rather than transferring for each call.
82-
8361
// multiplication and conversion from column-major to row-major (and back)
84-
auto A_tensor = make_torch_sparse_csr(A, get_torch_device(name()));
85-
auto X_tensor = make_torch_dense_tensor(X, get_torch_device(name()));
86-
auto Y_tensor = torch_tensor_transpose(torch::matmul(A_tensor, X_tensor)).to(torch::kCPU).contiguous();
62+
auto A_tensor = make_sparse_csr_tensor(A);
63+
auto X_tensor = make_dense_tensor(X);
64+
auto Y_tensor = tensor_transpose(tensor_to_host(torch::matmul(A_tensor, X_tensor)));
8765

8866
// assignment
8967
std::memcpy(Y.data(), Y_tensor.data_ptr<Scalar>(), Y.size() * sizeof(Scalar));
@@ -96,9 +74,4 @@ void LinearAlgebraTorch::dsptd(const Vector& x, const SparseMatrix& A, const Vec
9674
}
9775

9876

99-
void LinearAlgebraTorch::print(std::ostream& out) const {
100-
out << "LinearAlgebraTorch[]";
101-
}
102-
103-
10477
} // namespace eckit::linalg::sparse

src/eckit/linalg/sparse/LinearAlgebraTorch.h

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,17 +12,21 @@
1212
#pragma once
1313

1414
#include "eckit/linalg/LinearAlgebraSparse.h"
15+
#include "eckit/linalg/detail/Torch.h"
16+
1517

1618
namespace eckit::linalg::sparse {
1719

18-
struct LinearAlgebraTorch final : public LinearAlgebraSparse {
19-
LinearAlgebraTorch() = default;
20-
LinearAlgebraTorch(const std::string& name) : LinearAlgebraSparse(name) {}
20+
21+
struct LinearAlgebraTorch final : public LinearAlgebraSparse, detail::Torch {
22+
LinearAlgebraTorch(const std::string& name, torch::DeviceType device, torch::ScalarType scalar = torch::kFloat64) :
23+
LinearAlgebraSparse(name), Torch(device, scalar) {}
2124

2225
void spmv(const SparseMatrix&, const Vector&, Vector&) const override;
2326
void spmm(const SparseMatrix&, const Matrix&, Matrix&) const override;
2427
void dsptd(const Vector&, const SparseMatrix&, const Vector&, SparseMatrix&) const override;
25-
void print(std::ostream&) const override;
28+
void print(std::ostream& os) const override { Torch::print(os); }
2629
};
2730

31+
2832
} // namespace eckit::linalg::sparse

0 commit comments

Comments (0)