1212#include " mlir/IR/DialectImplementation.h"
1313#include " llvm/ADT/SmallVector.h"
1414#include " llvm/ADT/TypeSwitch.h"
15+ #include < numeric>
1516
1617namespace mlir {
1718namespace xegpu {
@@ -338,32 +339,30 @@ LogicalResult TensorDescType::verify(
338339// [n_distribution_units, lane_data_size]
339340FailureOr<VectorType> TensorDescType::getDistributedVectorType () {
340341 auto layout = llvm::dyn_cast_if_present<LayoutAttr>(getLayout ());
341- // If no layout is provided, tensor desc is not used in SIMT mode.
342- if (!layout)
342+ // It only works for subgroup level layout, which only has lane_layout
343+ // and lane_data, and is to distribute a SIMD code into SIMT code.
344+ if (!layout || !layout.isSgLayout ())
343345 return failure ();
344346
345347 SmallVector<int64_t > laneData (layout.getLaneData ().asArrayRef ());
346348 SmallVector<int64_t > laneLayout (layout.getLaneLayout ().asArrayRef ());
347349 auto tdescShape = getShape ();
348350
349- auto laneDataSize = 1 , sgSize = 1 ;
350- for ( auto [laneDim, laneDataDim] : llvm::zip_equal ( laneLayout, laneData)) {
351- laneDataSize *= laneDataDim;
352- sgSize *= laneDim;
353- }
351+ // compute sgSize by multiply elements of laneLayout
352+ // e.g. for 2D layout, sgSize = laneLayout[0] * laneLayout[1]
353+ // e.g. for 1D layout, sgSize = laneLayout[0]
354+ auto sgSize = std::accumulate (laneLayout. begin (), laneLayout. end (), 1 ,
355+ std::multiplies< int64_t >());
354356
355357 // Case 1: regular loads/stores
356358 auto scatterAttr = getEncodingAsScatterTensorDescAttr ();
357359 if (scatterAttr) {
358360 auto chunkSize = scatterAttr.getChunkSize ().getInt ();
359361 // Verify if the first dimension of the tensor descriptor shape is
360362 // distributable.
361- assert (tdescShape[0 ] % ( laneLayout[0 ]) == 0 &&
363+ assert (tdescShape[0 ] == laneLayout[0 ] &&
362364 " tensor descriptor shape is not distributable" );
363- if (chunkSize > 1 )
364- return VectorType::get ({chunkSize / laneDataSize, laneDataSize},
365- getElementType ());
366- return VectorType::get ({laneDataSize}, getElementType ());
365+ return VectorType::get ({chunkSize}, getElementType ());
367366 }
368367
369368 // Case 2: block loads/stores
@@ -378,12 +377,7 @@ FailureOr<VectorType> TensorDescType::getDistributedVectorType() {
378377 // tensorSize must be adjusted for array_length.
379378 tensorSize *= getArrayLength ();
380379
381- if (layout.getRank () == 1 ) {
382- return VectorType::get ({tensorSize / sgSize}, getElementType ());
383- }
384-
385- return VectorType::get ({tensorSize / (sgSize * laneDataSize), laneDataSize},
386- getElementType ());
380+ return VectorType::get ({tensorSize / sgSize}, getElementType ());
387381}
388382
389383} // namespace xegpu
0 commit comments