1010#include " mlir/IR/Builders.h"
1111#include " mlir/IR/DialectImplementation.h"
1212#include " llvm/ADT/TypeSwitch.h"
13+ #include < numeric>
1314
1415namespace mlir {
1516namespace xegpu {
@@ -336,32 +337,30 @@ LogicalResult TensorDescType::verify(
336337// [n_distribution_units, lane_data_size]
337338FailureOr<VectorType> TensorDescType::getDistributedVectorType () {
338339 auto layout = llvm::dyn_cast_if_present<LayoutAttr>(getLayout ());
339- // If no layout is provided, tensor desc is not used in SIMT mode.
340- if (!layout)
340+ // It only works for subgroup level layout, which only has lane_layout
341+ // and lane_data, and is to distribute a SIMD code into SIMT code.
342+ if (!layout || !layout.isSgLayout ())
341343 return failure ();
342344
343345 SmallVector<int64_t > laneData (layout.getLaneData ().asArrayRef ());
344346 SmallVector<int64_t > laneLayout (layout.getLaneLayout ().asArrayRef ());
345347 auto tdescShape = getShape ();
346348
347- auto laneDataSize = 1 , sgSize = 1 ;
348- for ( auto [laneDim, laneDataDim] : llvm::zip_equal ( laneLayout, laneData)) {
349- laneDataSize *= laneDataDim;
350- sgSize *= laneDim;
351- }
349+ // compute sgSize by multiply elements of laneLayout
350+ // e.g. for 2D layout, sgSize = laneLayout[0] * laneLayout[1]
351+ // e.g. for 1D layout, sgSize = laneLayout[0]
352+ auto sgSize = std::accumulate (laneLayout. begin (), laneLayout. end (), 1 ,
353+ std::multiplies< int64_t >());
352354
353355 // Case 1: regular loads/stores
354356 auto scatterAttr = getEncodingAsScatterTensorDescAttr ();
355357 if (scatterAttr) {
356358 auto chunkSize = scatterAttr.getChunkSize ().getInt ();
357359 // Verify if the first dimension of the tensor descriptor shape is
358360 // distributable.
359- assert (tdescShape[0 ] % ( laneLayout[0 ]) == 0 &&
361+ assert (tdescShape[0 ] == laneLayout[0 ] &&
360362 " tensor descriptor shape is not distributable" );
361- if (chunkSize > 1 )
362- return VectorType::get ({chunkSize / laneDataSize, laneDataSize},
363- getElementType ());
364- return VectorType::get ({laneDataSize}, getElementType ());
363+ return VectorType::get ({chunkSize}, getElementType ());
365364 }
366365
367366 // Case 2: block loads/stores
@@ -376,8 +375,7 @@ FailureOr<VectorType> TensorDescType::getDistributedVectorType() {
376375 // tensorSize must be adjusted for array_length.
377376 tensorSize *= getArrayLength ();
378377
379- return VectorType::get ({tensorSize / (sgSize * laneDataSize), laneDataSize},
380- getElementType ());
378+ return VectorType::get ({tensorSize / sgSize}, getElementType ());
381379}
382380
383381} // namespace xegpu
0 commit comments