Skip to content

Commit 2b8fc55

Browse files
authored
Enable FP16 support for all RunMatMulTest test cases (#22440)
### Description
Enable all RunMatMulTest test cases to run with FP16 inputs.
### Motivation and Context
Increase FP16 test coverage for all related execution providers (EPs).
1 parent af00a20 commit 2b8fc55

File tree

1 file changed

+36
-52
lines changed

1 file changed

+36
-52
lines changed

onnxruntime/test/providers/cpu/math/matmul_test.cc

Lines changed: 36 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -38,128 +38,125 @@ template <typename T>
3838
std::vector<MatMulTestData<T>> GenerateTestCases() {
3939
std::vector<MatMulTestData<T>> test_cases;
4040

41+
auto real_expected_vals = [](const std::vector<int32_t>& expected_vals) {
42+
if constexpr (std::is_same_v<T, int32_t>) {
43+
return expected_vals;
44+
} else if constexpr (std::is_same_v<T, MLFloat16>) {
45+
std::vector<MLFloat16> expected_vals_fp16(expected_vals.size());
46+
std::transform(expected_vals.begin(), expected_vals.end(), expected_vals_fp16.begin(),
47+
[](int32_t num) { return MLFloat16(float(num)); });
48+
return expected_vals_fp16;
49+
} else {
50+
std::vector<T> real_expected_vals(expected_vals.size());
51+
std::transform(expected_vals.begin(), expected_vals.end(), real_expected_vals.begin(),
52+
[](int32_t num) { return static_cast<T>(num); });
53+
return real_expected_vals;
54+
}
55+
};
56+
4157
test_cases.push_back(
4258
{"test padding and broadcast A > B",
4359
{3, 1, 1, 2},
4460
{2, 2, 2},
4561
{3, 2, 1, 2},
46-
{2, 3, 6, 7, 6, 11, 26, 31, 10, 19, 46, 55}});
62+
real_expected_vals({2, 3, 6, 7, 6, 11, 26, 31, 10, 19, 46, 55})});
4763

4864
test_cases.push_back(
4965
{"test padding and broadcast B > A",
5066
{2, 3, 2},
5167
{3, 2, 2, 1},
5268
{3, 2, 3, 1},
53-
{1, 3, 5, 33, 43, 53, 5, 23, 41, 85, 111, 137, 9, 43, 77, 137, 179, 221}});
69+
real_expected_vals({1, 3, 5, 33, 43, 53, 5, 23, 41, 85, 111, 137, 9, 43, 77, 137, 179, 221})});
5470

5571
test_cases.push_back(
5672
{"test left 1D",
5773
{2},
5874
{3, 2, 1},
5975
{3, 1},
60-
{1, 3, 5}});
76+
real_expected_vals({1, 3, 5})});
6177

6278
test_cases.push_back(
6379
{"test right 1D",
6480
{3, 1, 2},
6581
{2},
6682
{3, 1},
67-
{1, 3, 5}});
83+
real_expected_vals({1, 3, 5})});
6884

6985
test_cases.push_back(
7086
{"test left 1D right 2D",
7187
{2},
7288
{2, 3},
7389
{3},
74-
{3, 4, 5}});
90+
real_expected_vals({3, 4, 5})});
7591

7692
test_cases.push_back(
7793
{"test scalar output",
7894
{3},
7995
{3},
8096
{},
81-
{5}});
97+
real_expected_vals({5})});
8298

8399
test_cases.push_back(
84100
{"test 2D",
85101
{3, 4},
86102
{4, 3},
87103
{3, 3},
88-
{42, 48, 54, 114, 136, 158, 186, 224, 262}});
104+
real_expected_vals({42, 48, 54, 114, 136, 158, 186, 224, 262})});
89105

90106
test_cases.push_back(
91107
{"test 2D special",
92108
{2, 2, 3},
93109
{3, 4},
94110
{2, 2, 4},
95-
{20, 23, 26, 29, 56, 68, 80, 92, 92, 113, 134, 155, 128, 158, 188, 218}});
111+
real_expected_vals({20, 23, 26, 29, 56, 68, 80, 92, 92, 113, 134, 155, 128, 158, 188, 218})});
96112

97113
test_cases.push_back(
98114
{"test 2D special 2",
99115
{2, 2, 3},
100116
{1, 3, 4},
101117
{2, 2, 4},
102-
{20, 23, 26, 29, 56, 68, 80, 92, 92, 113, 134, 155, 128, 158, 188, 218}});
118+
real_expected_vals({20, 23, 26, 29, 56, 68, 80, 92, 92, 113, 134, 155, 128, 158, 188, 218})});
103119

104120
test_cases.push_back(
105121
{"test 2D special 3",
106122
{2, 6},
107123
{1, 1, 6, 1},
108124
{1, 1, 2, 1},
109-
{55, 145}});
125+
real_expected_vals({55, 145})});
110126

111127
test_cases.push_back(
112128
{"test 2D empty input",
113129
{3, 4},
114130
{4, 0},
115131
{3, 0},
116-
{}});
132+
real_expected_vals({})});
117133

118134
test_cases.push_back(
119135
{"test 3D batch",
120136
{3, 1, 3},
121137
{3, 3, 2},
122138
{3, 1, 2},
123-
{
139+
real_expected_vals({
124140
// clang-format off
125141
10, 13,
126142
100, 112,
127143
298, 319,
128144
// clang-format on
129-
}});
145+
})});
130146

131147
test_cases.push_back(
132148
{"test 4D batch",
133149
{2, 2, 1, 3},
134150
{2, 2, 3, 2},
135151
{2, 2, 1, 2},
136-
{
152+
real_expected_vals({
137153
// clang-format off
138154
10, 13,
139155
100, 112,
140156
298, 319,
141157
604, 634,
142158
// clang-format on
143-
}});
144-
145-
return test_cases;
146-
}
147-
148-
template <>
149-
std::vector<MatMulTestData<MLFloat16>> GenerateTestCases() {
150-
std::vector<MatMulTestData<MLFloat16>> test_cases;
151-
152-
// test 2D expected_vals
153-
std::vector<int64_t> expected_vals = {42, 48, 54, 114, 136, 158, 186, 224, 262};
154-
std::vector<MLFloat16> expected_vals_fp16(expected_vals.size());
155-
std::transform(expected_vals.begin(), expected_vals.end(), expected_vals_fp16.begin(),
156-
[](int64_t num) { return MLFloat16(float(num)); });
157-
test_cases.push_back(
158-
{"test 2D MLfloat16",
159-
{3, 4},
160-
{4, 3},
161-
{3, 3},
162-
expected_vals_fp16});
159+
})});
163160

164161
return test_cases;
165162
}
@@ -209,19 +206,12 @@ TEST(MathOpTest, MatMulFloatType) {
209206
GTEST_SKIP() << "Skipping because of the following error: Assertion failed: m_bufferTensorDesc.TotalTensorSizeInBytes >= ComputeByteSizeFromDimensions(nonBroadcastDimensions, dataType)";
210207
}
211208
RunMatMulTest<float>(7, false, false);
212-
}
213-
214-
// To Test XNNPACK, Matrix B must be constant
215-
TEST(MathOpTest, MatMulFloatType_ConstantB) {
216-
// TODO: Unskip when fixed #41968513
217-
if (DefaultDmlExecutionProvider().get() != nullptr) {
218-
GTEST_SKIP() << "Skipping because of the following error: Assertion failed: m_bufferTensorDesc.TotalTensorSizeInBytes >= ComputeByteSizeFromDimensions(nonBroadcastDimensions, dataType)";
219-
}
209+
// Note. Xnnpack only supports matmul when Matrix B is constant
220210
RunMatMulTest<float>(7, false, true);
221211
}
222212

223213
#if defined(USE_CUDA) || defined(USE_ROCM) || defined(COREML_ENABLE_MLPROGRAM) || defined(USE_XNNPACK)
224-
TEST(MathOpTest, MatMulFloat16_ConstantB) {
214+
TEST(MathOpTest, MatMulFloat16) {
225215
#ifdef USE_CUDA
226216
int min_cuda_architecture = 530;
227217
if (!HasCudaEnvironment(min_cuda_architecture)) {
@@ -233,22 +223,16 @@ TEST(MathOpTest, MatMulFloat16_ConstantB) {
233223
if (DefaultDmlExecutionProvider().get() != nullptr) {
234224
GTEST_SKIP() << "Skipping because of the following error: Assertion failed: m_bufferTensorDesc.TotalTensorSizeInBytes >= ComputeByteSizeFromDimensions(nonBroadcastDimensions, dataType)";
235225
}
236-
RunMatMulTest<MLFloat16>(7, false, true);
226+
RunMatMulTest<MLFloat16>(14, false, false);
227+
// Note. Xnnpack only supports matmul when Matrix B is constant
228+
RunMatMulTest<MLFloat16>(14, false, true);
237229
}
238230
#endif
239231

240232
TEST(MathOpTest, MatMulDoubleType) {
241233
RunMatMulTest<double>(7);
242234
}
243235

244-
TEST(MathOpTest, MatMulFloatTypeInitializer) {
245-
// TODO: Unskip when fixed #41968513
246-
if (DefaultDmlExecutionProvider().get() != nullptr) {
247-
GTEST_SKIP() << "Skipping because of the following error: Assertion failed: m_bufferTensorDesc.TotalTensorSizeInBytes >= ComputeByteSizeFromDimensions(nonBroadcastDimensions, dataType)";
248-
}
249-
RunMatMulTest<float>(7, false, true);
250-
}
251-
252236
TEST(MathOpTest, MatMulInt32Type) {
253237
RunMatMulTest<int32_t>(9);
254238
}

0 commit comments

Comments
 (0)