-
Notifications
You must be signed in to change notification settings - Fork 8
Expand file tree
/
Copy pathview_inspectors.hpp
More file actions
111 lines (90 loc) · 3.23 KB
/
view_inspectors.hpp
File metadata and controls
111 lines (90 loc) · 3.23 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
#pragma once
#include <optional>
#include <utility>

#include <spblas/detail/concepts.hpp>
#include <spblas/views/inspectors.hpp>
namespace spblas {

namespace __detail {

// Satisfied when T is a view whose base() member, called on an lvalue of T,
// yields something that itself satisfies `tensor` -- i.e. T wraps another
// tensor that can be unwrapped.
template <typename T>
concept has_base = view<T> && requires(T& t) {
  { t.base() } -> tensor;
};

// Inspect a tensor for scaling factors.
//
// Recursively walks every base() layer of `t`. Each layer satisfying
// is_scaled_view_v contributes its alpha(), and all contributions are
// multiplied together.
//
// Returns:
//   - an empty std::optional if no layer is a scaled view (its value type is
//     then tensor_scalar_t of the innermost tensor), or
//   - an engaged std::optional holding the product of every alpha().
template <tensor T>
[[nodiscard]] auto get_scaling_factor(T&& t) {
  if constexpr (has_base<T>) {
    auto base_scaling_factor = get_scaling_factor(t.base());
    if constexpr (is_scaled_view_v<T>) {
      auto scaling_factor = t.alpha();
      // Both return statements below must yield the same type, so the result
      // type is that of `alpha * base_alpha` even when the base has no
      // scaling factor of its own.
      using scaling_factor_type =
          decltype(scaling_factor * base_scaling_factor.value());
      if (base_scaling_factor.has_value()) {
        return std::optional<scaling_factor_type>(scaling_factor *
                                                  base_scaling_factor.value());
      } else {
        return std::optional<scaling_factor_type>(scaling_factor);
      }
    } else {
      // Unscaled wrapper: forward whatever the base reported.
      return base_scaling_factor;
    }
  } else {
    if constexpr (is_scaled_view_v<T>) {
      return std::optional(t.alpha());
    } else {
      return std::optional<tensor_scalar_t<T>>{};
    }
  }
}

// Combine the scaling factors of two tensors `t` and `u`, returning:
//   1) an empty optional, if neither has a scaling factor;
//   2) the scaling factor of whichever one has it, if only one does;
//   3) the product of both scaling factors, if both do.
// In every case the value type is the type of (t's scalar) * (u's scalar).
template <tensor T, tensor U>
[[nodiscard]] auto get_scaling_factor(T&& t, U&& u) {
  auto t_scaling_factor = get_scaling_factor(t);
  auto u_scaling_factor = get_scaling_factor(u);

  // The locals above are declared `auto`, so their declared types are plain
  // (non-reference, non-cv) std::optional specializations.
  using t_scalar_type = typename decltype(t_scaling_factor)::value_type;
  using u_scalar_type = typename decltype(u_scaling_factor)::value_type;
  using scalar_type =
      decltype(std::declval<t_scalar_type>() * std::declval<u_scalar_type>());

  // Convert the contained value explicitly instead of constructing one
  // optional from another, so the intent (and the constructor chosen) is
  // unambiguous.
  if (t_scaling_factor.has_value() && u_scaling_factor.has_value()) {
    return std::optional<scalar_type>(t_scaling_factor.value() *
                                      u_scaling_factor.value());
  } else if (t_scaling_factor.has_value()) {
    return std::optional<scalar_type>(t_scaling_factor.value());
  } else if (u_scaling_factor.has_value()) {
    return std::optional<scalar_type>(u_scaling_factor.value());
  } else {
    return std::optional<scalar_type>{};
  }
}

// True iff at least one base() layer of `t` (including `t` itself) is a
// scaled view, i.e. get_scaling_factor(t) is engaged.
template <tensor T>
[[nodiscard]] bool has_scaling_factor(T&& t) {
  return get_scaling_factor(t).has_value();
}

// Unwrap `t` through every base() layer and return the innermost tensor; a
// tensor without a base() is returned itself.
// The result is returned by value (`auto` return type).
// NOTE(review): this is cheap for views, but the leaf is copied if it owns
// its storage -- confirm callers only pass view-like leaves on hot paths.
template <tensor T>
[[nodiscard]] auto get_ultimate_base(T&& t) {
  if constexpr (has_base<T>) {
    return get_ultimate_base(t.base());
  } else {
    return t;
  }
}

// The (decayed) type get_ultimate_base returns for an argument of type T.
template <typename T>
using ultimate_base_type_t = decltype(get_ultimate_base(std::declval<T>()));

// The innermost base of T satisfies is_csr_view_v.
template <typename T>
concept has_csr_base = is_csr_view_v<ultimate_base_type_t<T>>;

// The innermost base of T satisfies is_csc_view_v.
template <typename T>
concept has_csc_base = is_csc_view_v<ultimate_base_type_t<T>>;

// The innermost base of T satisfies is_matrix_instantiation_of_mdspan_v.
template <typename T>
concept has_mdspan_matrix_base =
    is_matrix_instantiation_of_mdspan_v<ultimate_base_type_t<T>>;

// The innermost base of T models spblas::__ranges::contiguous_range.
template <typename T>
concept has_contiguous_range_base =
    spblas::__ranges::contiguous_range<ultimate_base_type_t<T>>;

} // namespace __detail

} // namespace spblas