Skip to content

Commit dab58c1

Browse files
committed
move activations out of common.hpp
1 parent 14d4ef6 commit dab58c1

File tree

2 files changed

+8
-11
lines changed

2 files changed

+8
-11
lines changed

sycl/test-e2e/Matrix/common.hpp

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -36,12 +36,6 @@ template <typename T, size_t NUM_ROWS, size_t NUM_COLS> struct big_matrix {
3636
big_matrix(T *data) : mat(data) {}
3737
};
3838

39-
enum class Activation {
40-
ReLU,
41-
Sigmoid,
42-
None,
43-
};
44-
4539
float make_fp32(bfloat16 x) {
4640
unsigned int y = *((int *)&x);
4741
y = y << 16;
@@ -169,8 +163,7 @@ void matrix_apply(unsigned int rows, unsigned int cols, T *mat, F op) {
169163
mat[i * cols + j] = op(mat[i * cols + j]);
170164
}
171165

172-
template <Activation act = Activation::None, typename T1, typename T2,
173-
bool exact = false>
166+
template <typename T1, typename T2, bool exact = false>
174167
bool matrix_compare(unsigned int rows, unsigned int cols, T1 *src, T2 *ref) {
175168
for (int i = 0; i < rows; i++) {
176169
for (int j = 0; j < cols; j++) {

sycl/test-e2e/Matrix/joint_matrix_activation_impl.hpp

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,12 @@ constexpr size_t TM = 8;
1515
constexpr size_t TN = 16;
1616
constexpr size_t TK = 16;
1717

18+
enum class Activation {
19+
ReLU,
20+
Sigmoid,
21+
None,
22+
};
23+
1824
template <typename T> T ReLU(T x) { return sycl::max(static_cast<T>(0), x); }
1925

2026
template <typename T> T Sigmoid(T x) {
@@ -105,9 +111,7 @@ int main() {
105111
big_matrix<bfloat16, MATRIX_M, MATRIX_K> MA((bfloat16 *)&A);
106112

107113
matrix_activation_copy<Activation::None>(MC, MA);
108-
bool res0 = matrix_compare(MATRIX_M, MATRIX_N, (bfloat16 *)A, (float *)C);
109-
bool res = matrix_compare<Activation::None>(MATRIX_M, MATRIX_N, (bfloat16 *)A,
110-
(float *)C);
114+
bool res = matrix_compare(MATRIX_M, MATRIX_N, (bfloat16 *)A, (float *)C);
111115
std::cout << (res ? "Copy passed" : "Copy failed") << std::endl;
112116

113117
matrix_activation_copy<Activation::ReLU>(MC, MA);

0 commit comments

Comments
 (0)