@@ -196,6 +196,41 @@ TEST_P(CastCSTestSuite, TestCastCS) {
196196}
197197
198198
199+ TEST (AmaxConsistencyTest, AtomicVsWorkspace) {
200+ using namespace transformer_engine ;
201+ using namespace test ;
202+
203+ std::vector<size_t > shape{256 , 1024 };
204+ const size_t N = product (shape);
205+
206+ // Input: FP32, Output: FP8 (E4M3) with per-tensor scaling
207+ Tensor input (" input" , shape, DType::kFloat32 );
208+ Tensor out_atomic (" out_atomic" , shape, DType::kFloat8E4M3 , true , false );
209+ Tensor out_ws (" out_ws" , shape, DType::kFloat8E4M3 , true , false );
210+
211+ fillUniform (&input);
212+
213+ // Path 1: atomic-based amax (no workspace)
214+ nvte_compute_amax (input.data (), out_atomic.data (), 0 );
215+
216+ // Path 2: two-stage amax using workspace
217+ // Use a workspace capacity >= number of blocks
218+ std::vector<size_t > ws_shape{N};
219+ Tensor workspace (" workspace" , ws_shape, DType::kFloat32 );
220+ nvte_compute_amax_with_workspace (input.data (), out_ws.data (), workspace.data (), 0 );
221+
222+ cudaDeviceSynchronize ();
223+ auto err = cudaGetLastError ();
224+ ASSERT_EQ (err, cudaSuccess) << cudaGetErrorString (err);
225+
226+ // Compare the resulting amax values
227+ float amax_atomic = out_atomic.amax ();
228+ float amax_ws = out_ws.amax ();
229+
230+ compareResults (" amax_consistency" , amax_atomic, amax_ws, /* atol=*/ 0 .0f , /* rtol=*/ 0 .0f );
231+ }
232+
233+
199234
200235INSTANTIATE_TEST_SUITE_P (
201236 OperatorTest,
0 commit comments