@@ -368,6 +368,8 @@ TEST_F(AMDMfmaLayoutTest, mfma_dot_op) {
368368 auto dot2dOp1 = createDotOperand (1 , mfma2d, 4 );
369369 ASSERT_THAT (dot2dOp0.getWarpOrder (), mfma2d.getWarpOrder ());
370370 ASSERT_THAT (dot2dOp1.getWarpOrder (), mfma2d.getWarpOrder ());
371+ ASSERT_THAT (dot2dOp0.getThreadsPerWarp (), testing::ElementsAre (32u , 2u ));
372+ ASSERT_THAT (dot2dOp1.getThreadsPerWarp (), testing::ElementsAre (2u , 32u ));
371373
372374 auto tmfma2d = createTransposedMFMA (32 , 32 , {2 , 4 });
373375 auto tdot2dOp0 = createDotOperand (0 , tmfma2d, 4 );
@@ -380,12 +382,28 @@ TEST_F(AMDMfmaLayoutTest, mfma_dot_op) {
380382 auto dot3dOp1 = createDotOperand (1 , mfma3d, 4 );
381383 ASSERT_THAT (dot3dOp0.getWarpOrder (), mfma3d.getWarpOrder ());
382384 ASSERT_THAT (dot3dOp1.getWarpOrder (), mfma3d.getWarpOrder ());
385+ ASSERT_THAT (dot3dOp0.getThreadsPerWarp (), testing::ElementsAre (1u , 32u , 2u ));
386+ ASSERT_THAT (dot3dOp1.getThreadsPerWarp (), testing::ElementsAre (1u , 2u , 32u ));
383387
384388 auto tmfma3d = createTransposedMFMA (32 , 32 , {2 , 4 , 1 });
385389 auto tdot3dOp0 = createDotOperand (0 , tmfma3d, 4 );
386390 auto tdot3dOp1 = createDotOperand (1 , tmfma3d, 4 );
387391 ASSERT_THAT (tdot3dOp0.getWarpOrder (), tmfma3d.getWarpOrder ());
388392 ASSERT_THAT (tdot3dOp1.getWarpOrder (), tmfma3d.getWarpOrder ());
393+
394+ auto mfma16_2d = createMFMA (16 , 16 , {2 , 4 });
395+ auto dot16_2dOp0 = createDotOperand (0 , mfma16_2d, 4 );
396+ auto dot16_2dOp1 = createDotOperand (1 , mfma16_2d, 4 );
397+ ASSERT_THAT (dot16_2dOp0.getThreadsPerWarp (), testing::ElementsAre (16u , 4u ));
398+ ASSERT_THAT (dot16_2dOp1.getThreadsPerWarp (), testing::ElementsAre (4u , 16u ));
399+
400+ auto mfma16_3d = createMFMA (16 , 16 , {2 , 4 , 1 });
401+ auto dot16_3dOp0 = createDotOperand (0 , mfma16_3d, 4 );
402+ auto dot16_3dOp1 = createDotOperand (1 , mfma16_3d, 4 );
403+ ASSERT_THAT (dot16_3dOp0.getThreadsPerWarp (),
404+ testing::ElementsAre (1u , 16u , 4u ));
405+ ASSERT_THAT (dot16_3dOp1.getThreadsPerWarp (),
406+ testing::ElementsAre (1u , 4u , 16u ));
389407}
390408
391409TEST_F (AMDWmmaLayoutTest, wmmaV1) {
@@ -434,24 +452,36 @@ TEST_F(AMDWmmaLayoutTest, wmma_dot_op) {
434452 auto dot2dVer1Op1 = createDotOperand (1 , wmma2dVer1, 16 );
435453 ASSERT_THAT (dot2dVer1Op0.getWarpOrder (), wmma2dVer1.getWarpOrder ());
436454 ASSERT_THAT (dot2dVer1Op1.getWarpOrder (), wmma2dVer1.getWarpOrder ());
455+ ASSERT_THAT (dot2dVer1Op0.getThreadsPerWarp (), testing::ElementsAre (16u , 1u ));
456+ ASSERT_THAT (dot2dVer1Op1.getThreadsPerWarp (), testing::ElementsAre (1u , 16u ));
437457
438- auto wmma3dVer1 = createWMMAv1 ({2 , 4 });
458+ auto wmma3dVer1 = createWMMAv1 ({2 , 4 , 1 });
439459 auto dot3dVer1Op0 = createDotOperand (0 , wmma3dVer1, 16 );
440460 auto dot3dVer1Op1 = createDotOperand (1 , wmma3dVer1, 16 );
441461 ASSERT_THAT (dot3dVer1Op0.getWarpOrder (), wmma3dVer1.getWarpOrder ());
442462 ASSERT_THAT (dot3dVer1Op1.getWarpOrder (), wmma3dVer1.getWarpOrder ());
463+ ASSERT_THAT (dot3dVer1Op0.getThreadsPerWarp (),
464+ testing::ElementsAre (1 , 16u , 1u ));
465+ ASSERT_THAT (dot3dVer1Op1.getThreadsPerWarp (),
466+ testing::ElementsAre (1 , 1u , 16u ));
443467
444468 auto wmma2dVer2 = createWMMAv2 (false , {2 , 4 });
445469 auto dot2dVer2Op0 = createDotOperand (0 , wmma2dVer2, 16 );
446470 auto dot2dVer2Op1 = createDotOperand (1 , wmma2dVer2, 16 );
447471 ASSERT_THAT (dot2dVer2Op0.getWarpOrder (), wmma2dVer2.getWarpOrder ());
448472 ASSERT_THAT (dot2dVer2Op1.getWarpOrder (), wmma2dVer2.getWarpOrder ());
473+ ASSERT_THAT (dot2dVer2Op0.getThreadsPerWarp (), testing::ElementsAre (16u , 2u ));
474+ ASSERT_THAT (dot2dVer2Op1.getThreadsPerWarp (), testing::ElementsAre (2u , 16u ));
449475
450- auto wmma3dVer2 = createWMMAv2 (false , {2 , 4 });
476+ auto wmma3dVer2 = createWMMAv2 (false , {2 , 4 , 1 });
451477 auto dot3dVer2Op0 = createDotOperand (0 , wmma3dVer2, 16 );
452478 auto dot3dVer2Op1 = createDotOperand (1 , wmma3dVer2, 16 );
453479 ASSERT_THAT (dot3dVer2Op0.getWarpOrder (), wmma3dVer2.getWarpOrder ());
454480 ASSERT_THAT (dot3dVer2Op1.getWarpOrder (), wmma3dVer2.getWarpOrder ());
481+ ASSERT_THAT (dot3dVer2Op0.getThreadsPerWarp (),
482+ testing::ElementsAre (1 , 16u , 2u ));
483+ ASSERT_THAT (dot3dVer2Op1.getThreadsPerWarp (),
484+ testing::ElementsAre (1 , 2u , 16u ));
455485}
456486
457487class LinearEncodingTest : public ::testing::Test {
0 commit comments