Skip to content

Commit 115dc22

Browse files
committed
Re-write allocator -> hip_allocator
1 parent f3f93d0 commit 115dc22

File tree

7 files changed

+155
-171
lines changed

7 files changed

+155
-171
lines changed

include/spblas/allocator.hpp

Lines changed: 0 additions & 18 deletions
This file was deleted.

include/spblas/array.hpp

Lines changed: 0 additions & 58 deletions
This file was deleted.

include/spblas/vendor/rocsparse/allocator.hpp

Lines changed: 0 additions & 48 deletions
This file was deleted.

include/spblas/vendor/rocsparse/exception.hpp

Lines changed: 62 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -5,43 +5,79 @@
55
#include <stdexcept>
66
#include <string>
77

8-
namespace spblas::detail {
8+
namespace spblas {
99

10-
// throw an exception if the hipError_t is not hipSuccess.
11-
void throw_if_error(hipError_t error_code) {
10+
namespace __rocsparse {
11+
12+
// Throw an exception if the hipError_t is not hipSuccess.
13+
void throw_if_error(hipError_t error_code, std::string prefix = "") {
1214
if (error_code == hipSuccess) {
1315
return;
1416
}
1517
std::string name = hipGetErrorName(error_code);
1618
std::string message = hipGetErrorString(error_code);
17-
throw std::runtime_error(name + ":" + message);
19+
throw std::runtime_error(prefix + "HIP encountered an error " + name +
20+
": \"" + message + "\"");
1821
}
1922

20-
// throw an exception if the rocsparse_status is not rocsparse_status_success.
23+
// Throw an exception if the rocsparse_status is not rocsparse_status_success.
2124
void throw_if_error(rocsparse_status error_code) {
22-
#define REGISTER_ROCSPARSE_ERROR(error_name) \
23-
if (error_code == error_name) { \
24-
throw std::runtime_error(#error_name); \
25-
}
26-
2725
if (error_code == rocsparse_status_success) {
2826
return;
27+
} else if (error_code == rocsparse_status_invalid_handle) {
28+
throw std::runtime_error(
29+
"rocSPARSE encountered an error: \"rocsparse_status_invalid_handle\"");
30+
} else if (error_code == rocsparse_status_not_implemented) {
31+
throw std::runtime_error(
32+
"rocSPARSE encountered an error: \"rocsparse_status_not_implemented\"");
33+
} else if (error_code == rocsparse_status_invalid_pointer) {
34+
throw std::runtime_error(
35+
"rocSPARSE encountered an error: \"rocsparse_status_invalid_pointer\"");
36+
} else if (error_code == rocsparse_status_invalid_size) {
37+
throw std::runtime_error(
38+
"rocSPARSE encountered an error: \"rocsparse_status_invalid_size\"");
39+
} else if (error_code == rocsparse_status_memory_error) {
40+
throw std::runtime_error(
41+
"rocSPARSE encountered an error: \"rocsparse_status_memory_error\"");
42+
} else if (error_code == rocsparse_status_internal_error) {
43+
throw std::runtime_error(
44+
"rocSPARSE encountered an error: \"rocsparse_status_internal_error\"");
45+
} else if (error_code == rocsparse_status_invalid_value) {
46+
throw std::runtime_error(
47+
"rocSPARSE encountered an error: \"rocsparse_status_invalid_value\"");
48+
} else if (error_code == rocsparse_status_arch_mismatch) {
49+
throw std::runtime_error(
50+
"rocSPARSE encountered an error: \"rocsparse_status_arch_mismatch\"");
51+
} else if (error_code == rocsparse_status_zero_pivot) {
52+
throw std::runtime_error(
53+
"rocSPARSE encountered an error: \"rocsparse_status_zero_pivot\"");
54+
} else if (error_code == rocsparse_status_not_initialized) {
55+
throw std::runtime_error(
56+
"rocSPARSE encountered an error: \"rocsparse_status_not_initialized\"");
57+
} else if (error_code == rocsparse_status_type_mismatch) {
58+
throw std::runtime_error(
59+
"rocSPARSE encountered an error: \"rocsparse_status_type_mismatch\"");
60+
} else if (error_code == rocsparse_status_type_mismatch) {
61+
throw std::runtime_error(
62+
"rocSPARSE encountered an error: \"rocsparse_status_invalid_size\"");
63+
} else if (error_code == rocsparse_status_invalid_size) {
64+
throw std::runtime_error(
65+
"rocSPARSE encountered an error: \"rocsparse_status_invalid_size\"");
66+
} else if (error_code == rocsparse_status_invalid_size) {
67+
throw std::runtime_error(
68+
"rocSPARSE encountered an error: \"rocsparse_status_invalid_size\"");
69+
} else if (error_code == rocsparse_status_invalid_size) {
70+
throw std::runtime_error(
71+
"rocSPARSE encountered an error: \"rocsparse_status_invalid_size\"");
72+
} else if (error_code == rocsparse_status_invalid_size) {
73+
throw std::runtime_error(
74+
"rocSPARSE encountered an error: \"rocsparse_status_invalid_size\"");
75+
} else {
76+
throw std::runtime_error(
77+
"rocSPARSE encountered an error: \"unknown error\"");
2978
}
30-
31-
REGISTER_ROCSPARSE_ERROR(rocsparse_status_invalid_handle);
32-
REGISTER_ROCSPARSE_ERROR(rocsparse_status_not_implemented);
33-
REGISTER_ROCSPARSE_ERROR(rocsparse_status_invalid_pointer);
34-
REGISTER_ROCSPARSE_ERROR(rocsparse_status_invalid_size);
35-
REGISTER_ROCSPARSE_ERROR(rocsparse_status_memory_error);
36-
REGISTER_ROCSPARSE_ERROR(rocsparse_status_internal_error);
37-
REGISTER_ROCSPARSE_ERROR(rocsparse_status_invalid_value);
38-
REGISTER_ROCSPARSE_ERROR(rocsparse_status_arch_mismatch);
39-
REGISTER_ROCSPARSE_ERROR(rocsparse_status_zero_pivot);
40-
REGISTER_ROCSPARSE_ERROR(rocsparse_status_not_initialized);
41-
REGISTER_ROCSPARSE_ERROR(rocsparse_status_type_mismatch);
42-
#undef REGISTER_ROCSPARSE_ERROR
43-
44-
throw std::runtime_error("Unknown error from rocsparse_status");
4579
}
4680

47-
} // namespace spblas::detail
81+
} // namespace __rocsparse
82+
83+
} // namespace spblas
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
#pragma once
2+
3+
#include "exception.hpp"
4+
#include <hip/hip_runtime.h>
5+
#include <hip/hip_runtime_api.h>
6+
7+
namespace spblas {
8+
9+
namespace rocsparse {
10+
11+
template <typename T, std::size_t Alignment = 0>
12+
class hip_allocator {
13+
public:
14+
using value_type = T;
15+
using pointer = T*;
16+
using const_pointer = const T*;
17+
using reference = T&;
18+
using const_reference = const T&;
19+
using size_type = std::size_t;
20+
using difference_type = std::ptrdiff_t;
21+
22+
hip_allocator() noexcept {}
23+
hip_allocator(hipStream_t stream) noexcept : stream_(stream) {}
24+
25+
template <typename U>
26+
hip_allocator(const hip_allocator<U, Alignment>& other) noexcept
27+
: stream_(other.stream()) {}
28+
29+
hip_allocator(const hip_allocator&) = default;
30+
hip_allocator& operator=(const hip_allocator&) = default;
31+
~hip_allocator() = default;
32+
33+
using is_always_equal = std::false_type;
34+
35+
pointer allocate(std::size_t size) {
36+
void* ptr;
37+
hipError_t error = hipMallocAsync(&ptr, size * sizeof(T), stream());
38+
throw_if_failure(error);
39+
40+
return reinterpret_cast<T*>(ptr);
41+
}
42+
43+
void deallocate(pointer ptr, std::size_t n = 0) {
44+
if (ptr != nullptr) {
45+
hipError_t error = hipFreeAsync(ptr, stream());
46+
throw_if_failure(error);
47+
}
48+
}
49+
50+
bool operator==(const hip_allocator&) const = default;
51+
bool operator!=(const hip_allocator&) const = default;
52+
53+
template <typename U>
54+
struct rebind {
55+
using other = hip_allocator<U, Alignment>;
56+
};
57+
58+
hipStream_t stream() const noexcept {
59+
return stream_;
60+
}
61+
62+
private:
63+
void throw_if_failure(hipError_t error) {
64+
if (error != hipSuccess) {
65+
throw std::bad_alloc{};
66+
}
67+
}
68+
69+
hipStream_t stream_ = nullptr;
70+
};
71+
72+
} // namespace rocsparse
73+
74+
} // namespace spblas

0 commit comments

Comments
 (0)