@@ -341,38 +341,44 @@ LinearLayout combineCtaCgaWithShape(LinearLayout ctaLayout,
341341
342342} // anonymous namespace
343343
344+ // clang-format off
344345// The layout example repeat_count=8, systolic_depth=8,
345346// execution_size=16 and operands_per_chan=2 for warp size 32.
346347// For A operand:
347- // systolic depth = 8
348- // <----------------------------------------------------->
349- // opsPerChan=2
350- // <--------->
351- // t0 ... t0 t1 ... t1 ~ t6 ... t6 t7 ... t7 ^
352- // t8 ... t8 t9 ... t9 ~ t14 ... t14 t15 ... t15 |
353- // t16 ... t16 t17 ... t17 ~ t22 ... t22 t23 ... t23 |
354- // t24 ... t24 t25 ... t25 ~ t30 ... t30 t31 ... t31 | repeat count <= 8
355- // t0 ... t0 t1 ... t1 ~ t6 ... t6 t7 ... t7 |
356- // t8 ... t8 t9 ... t9 ~ t14 ... t14 t15 ... t15 |
357- // t16 ... t16 t17 ... t17 ~ t22 ... t22 t23 ... t23 |
358- // t24 ... t24 t25 ... t25 ~ t30 ... t30 t31 ... t31 v
348+ // K = 16 (K = systolic depth * opsPerChan)
349+ // <---------------------------------------------------------------------------->
350+ // t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 ^
351+ // t16 t17 t18 t19 t20 t21 t22 t23 t24 t25 t26 t27 t28 t29 t30 t31 |
352+ // t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 |
353+ // t16 t17 t18 t19 t20 t21 t22 t23 t24 t25 t26 t27 t28 t29 t30 t31 |
354+ // t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 | M = 8 (repeat count)
355+ // t16 t17 t18 t19 t20 t21 t22 t23 t24 t25 t26 t27 t28 t29 t30 t31 |
356+ // t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 |
357+ // t16 t17 t18 t19 t20 t21 t22 t23 t24 t25 t26 t27 t28 t29 t30 t31 v
359358// In this case, the LinearLayout bases are:
360- // Register: {{0,1}, {4,0}}
361- // Lane: {{0,2}, {0,4}, {0,8}, {1,0}, {2,0}}
359+ // Register: {{2,0}, {4,0}}
360+ // Lane: {{0,1}, {0,2}, {0,4}, {0,8}, {2,0}}
361+ // clang-format on
362362std::vector<std::vector<int32_t >> DPASRegBasesA (int opsPerChannel,
363363 int repeatCount,
364364 int threadsPerWarp,
365365 int systolicDepth) {
366- int rowPerWarp = threadsPerWarp / systolicDepth;
367- int warpRepeats = repeatCount / rowPerWarp;
368366 std::vector<std::vector<int32_t >> regBases;
369367
370- for (int opc = 1 ; opc < opsPerChannel; opc *= 2 ) {
368+ // pack the value to i16 for scalar bit width <=16.
369+ assert ((opsPerChannel == 4 || opsPerChannel == 2 || opsPerChannel == 1 ) &&
370+ " invalid opsPerChannel number." );
371+ int packedOpsPerLane = opsPerChannel == 4 ? 2 : 1 ;
372+ int packedColNum = (systolicDepth * opsPerChannel) / packedOpsPerLane;
373+ int rowsPerWarp = mlir::ceil<int >(threadsPerWarp, packedColNum);
374+ int warpRepeats = repeatCount / rowsPerWarp;
375+
376+ for (int opc = 1 ; opc < packedOpsPerLane; opc *= 2 ) {
371377 regBases.push_back ({0 , opc});
372378 }
373379
374380 for (int warp = 1 ; warp < warpRepeats; warp *= 2 ) {
375- regBases.push_back ({warp * rowPerWarp , 0 });
381+ regBases.push_back ({warp * rowsPerWarp , 0 });
376382 }
377383
378384 return regBases;
@@ -382,11 +388,17 @@ std::vector<std::vector<int32_t>>
382388DPASLaneBasesA (int opsPerChannel, int threadsPerWarp, int systolicDepth) {
383389 std::vector<std::vector<int32_t >> laneBases;
384390
385- for (int tid = 1 ; tid < systolicDepth; tid *= 2 ) {
386- laneBases.push_back ({0 , opsPerChannel * tid});
391+ // pack the value to i16 for scalar bit width <=16.
392+ assert ((opsPerChannel == 4 || opsPerChannel == 2 || opsPerChannel == 1 ) &&
393+ " invalid opsPerChannel number." );
394+ int packedOpsPerLane = opsPerChannel == 4 ? 2 : 1 ;
395+ int packedColNum = (systolicDepth * opsPerChannel) / packedOpsPerLane;
396+
397+ for (int tid = 1 ; tid < packedColNum; tid *= 2 ) {
398+ laneBases.push_back ({0 , packedOpsPerLane * tid});
387399 }
388- for (int tid = systolicDepth ; tid < threadsPerWarp; tid *= 2 ) {
389- laneBases.push_back ({tid / systolicDepth , 0 });
400+ for (int tid = packedColNum ; tid < threadsPerWarp; tid *= 2 ) {
401+ laneBases.push_back ({tid / packedColNum , 0 });
390402 }
391403
392404 return laneBases;
@@ -602,8 +614,7 @@ std::optional<LinearLayout>
602614dotOperandDpasToLinearLayout (DotOperandEncodingAttr dotDpasLayout,
603615 ArrayRef<int64_t > shape) {
604616 auto dpasLayout = cast<intel::DpasEncodingAttr>(dotDpasLayout.getParent ());
605- if (dotDpasLayout.getOpIdx () == 0 )
606- return std::nullopt ;
617+
607618 return DPAStoLinearLayout (shape, dpasLayout, dotDpasLayout.getOpIdx ());
608619}
609620
0 commit comments