
Commit be0e0c8

add a test to compare the accuracy of both amax implementations
1 parent 91249cc commit be0e0c8

File tree: 1 file changed (+35, -0 lines changed)


tests/cpp/operator/test_cast_current_scaling.cu

Lines changed: 35 additions & 0 deletions
@@ -196,6 +196,41 @@ TEST_P(CastCSTestSuite, TestCastCS) {
 }
 
 
+TEST(AmaxConsistencyTest, AtomicVsWorkspace) {
+  using namespace transformer_engine;
+  using namespace test;
+
+  std::vector<size_t> shape{256, 1024};
+  const size_t N = product(shape);
+
+  // Input: FP32, Output: FP8 (E4M3) with per-tensor scaling
+  Tensor input("input", shape, DType::kFloat32);
+  Tensor out_atomic("out_atomic", shape, DType::kFloat8E4M3, true, false);
+  Tensor out_ws("out_ws", shape, DType::kFloat8E4M3, true, false);
+
+  fillUniform(&input);
+
+  // Path 1: atomic-based amax (no workspace)
+  nvte_compute_amax(input.data(), out_atomic.data(), 0);
+
+  // Path 2: two-stage amax using a workspace
+  // Use a workspace capacity >= number of blocks
+  std::vector<size_t> ws_shape{N};
+  Tensor workspace("workspace", ws_shape, DType::kFloat32);
+  nvte_compute_amax_with_workspace(input.data(), out_ws.data(), workspace.data(), 0);
+
+  cudaDeviceSynchronize();
+  auto err = cudaGetLastError();
+  ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
+
+  // Compare the resulting amax values
+  float amax_atomic = out_atomic.amax();
+  float amax_ws = out_ws.amax();
+
+  compareResults("amax_consistency", amax_atomic, amax_ws, /*atol=*/0.0f, /*rtol=*/0.0f);
+}
+
+
 
 INSTANTIATE_TEST_SUITE_P(
   OperatorTest,
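
For context on what "both amax implementations" refers to, below is a minimal, self-contained CUDA sketch of the two reduction strategies such a test typically compares: a single-pass kernel that folds each block's partial maximum into a global amax with an atomic update, and a two-stage version that first writes per-block partial maxima into a workspace and then reduces that workspace in a second kernel. This is an illustrative standalone program, not the Transformer Engine kernels or API; the kernel names amaxAtomic, amaxPartial, amaxFinalize and the helper blockAbsMax are hypothetical.

// Illustrative sketch only: atomic vs. two-stage (workspace) amax reduction.
#include <cuda_runtime.h>
#include <cstdio>
#include <cmath>
#include <vector>

// Atomic max on a non-negative float via its bit pattern (for non-negative
// IEEE-754 floats, integer ordering matches floating-point ordering).
__device__ void atomicMaxFloat(float *addr, float val) {
  atomicMax(reinterpret_cast<unsigned int *>(addr), __float_as_uint(val));
}

// Block-level max of |x| over a grid-stride range; result is in smem[0].
__device__ float blockAbsMax(const float *x, size_t n) {
  __shared__ float smem[256];
  float m = 0.f;
  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < n; i += gridDim.x * blockDim.x)
    m = fmaxf(m, fabsf(x[i]));
  smem[threadIdx.x] = m;
  __syncthreads();
  for (int s = blockDim.x / 2; s > 0; s >>= 1) {
    if (threadIdx.x < s) smem[threadIdx.x] = fmaxf(smem[threadIdx.x], smem[threadIdx.x + s]);
    __syncthreads();
  }
  return smem[0];
}

// Path 1: single pass, each block atomically updates the global amax.
__global__ void amaxAtomic(const float *x, size_t n, float *amax) {
  float block_max = blockAbsMax(x, n);
  if (threadIdx.x == 0) atomicMaxFloat(amax, block_max);
}

// Path 2, stage 1: each block writes its partial max into the workspace.
__global__ void amaxPartial(const float *x, size_t n, float *workspace) {
  float block_max = blockAbsMax(x, n);
  if (threadIdx.x == 0) workspace[blockIdx.x] = block_max;
}

// Path 2, stage 2: one block reduces the workspace into the final amax.
__global__ void amaxFinalize(const float *workspace, size_t num_partials, float *amax) {
  float block_max = blockAbsMax(workspace, num_partials);
  if (threadIdx.x == 0) *amax = block_max;
}

int main() {
  const size_t n = 256 * 1024;
  const int threads = 256, blocks = 128;

  std::vector<float> h(n);
  for (size_t i = 0; i < n; ++i) h[i] = std::sin(0.001f * i);  // arbitrary test data

  float *d_x, *d_ws, *d_amax_atomic, *d_amax_ws;
  cudaMalloc(&d_x, n * sizeof(float));
  cudaMalloc(&d_ws, blocks * sizeof(float));
  cudaMalloc(&d_amax_atomic, sizeof(float));
  cudaMalloc(&d_amax_ws, sizeof(float));
  cudaMemcpy(d_x, h.data(), n * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemset(d_amax_atomic, 0, sizeof(float));  // atomic path needs a zero-initialized amax

  amaxAtomic<<<blocks, threads>>>(d_x, n, d_amax_atomic);
  amaxPartial<<<blocks, threads>>>(d_x, n, d_ws);
  amaxFinalize<<<1, threads>>>(d_ws, blocks, d_amax_ws);
  cudaDeviceSynchronize();

  float a0, a1;
  cudaMemcpy(&a0, d_amax_atomic, sizeof(float), cudaMemcpyDeviceToHost);
  cudaMemcpy(&a1, d_amax_ws, sizeof(float), cudaMemcpyDeviceToHost);
  printf("atomic amax = %f, workspace amax = %f\n", a0, a1);

  cudaFree(d_x); cudaFree(d_ws); cudaFree(d_amax_atomic); cudaFree(d_amax_ws);
  return 0;
}

Both paths compute an exact maximum of absolute values, so under this sketch's assumptions they should agree bit-for-bit, which is consistent with the test above comparing the two results with zero tolerance.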
