1010#include " mlir/IR/Builders.h"
1111#include " mlir/IR/DialectImplementation.h"
1212#include " llvm/ADT/TypeSwitch.h"
13+ #include < numeric>
1314
1415namespace mlir {
1516namespace xegpu {
@@ -319,49 +320,48 @@ LogicalResult TensorDescType::verify(
319320// ---------------------------------------------------------------------
320321// Case 1: Regular loads/stores.
321322// ---------------------------------------------------------------------
322- // Distributed vector shape must be:
323- // [chunk_size / lane_data_size, lane_data_size ]
324- // If the tensor descriptor shape is 1D, first dimension is ignored (set to 1).
325- // [lane_data_size ]
323+ // The following conditions must be met:
324+ // * tensor_desc[0] == lane_layout[0]
325+ // Distributed vector is a 1D vector with shape:
326+ // [chunk_size]
326327// ---------------------------------------------------------------------
327328// Case 2: Block loads/stores
328329// ---------------------------------------------------------------------
329330// Additional definitions:
330331// tensor_size = tensor_desc[0] * .. * tensor_desc[r-1] * array_length
331332// n_distribution_units = tensor_size / distribution_unit_size
333+ // fragment_size = n_distribution_units * lane_data_size
332334// Given above definitions, the following conditions must be met:
333335// * tensor_desc[0] % (lane_layout[0] × lane_data[0]) == 0
334336// * tensor_desc[1] % (lane_layout[1] × lane_data[1]) == 0
335- // Distributed vector shape must be :
336- // [n_distribution_units, lane_data_size ]
337+ // Distributed vector is a 1D vector with shape:
338+ // [fragment_size]
337339FailureOr<VectorType> TensorDescType::getDistributedVectorType () {
338340 auto layout = llvm::dyn_cast_if_present<LayoutAttr>(getLayout ());
339- // If no layout is provided, tensor desc is not used in SIMT mode.
340- if (!layout)
341+ // This only works for a subgroup-level layout, which has only lane_layout
342+ // and lane_data and is used to distribute SIMD code into SIMT code.
343+ if (!layout || !layout.isSgLayout ())
341344 return failure ();
342345
343346 SmallVector<int64_t > laneData (layout.getLaneData ().asArrayRef ());
344347 SmallVector<int64_t > laneLayout (layout.getLaneLayout ().asArrayRef ());
345348 auto tdescShape = getShape ();
346349
347- auto laneDataSize = 1 , sgSize = 1 ;
348- for ( auto [laneDim, laneDataDim] : llvm::zip_equal ( laneLayout, laneData)) {
349- laneDataSize *= laneDataDim;
350- sgSize *= laneDim;
351- }
350+ // compute sgSize by multiplying the elements of laneLayout
351+ // e.g. for 2D layout, sgSize = laneLayout[0] * laneLayout[1]
352+ // e.g. for 1D layout, sgSize = laneLayout[0]
353+ auto sgSize = std::accumulate (laneLayout. begin (), laneLayout. end (), 1 ,
354+ std::multiplies< int64_t >());
352355
353356 // Case 1: regular loads/stores
354357 auto scatterAttr = getEncodingAsScatterTensorDescAttr ();
355358 if (scatterAttr) {
356359 auto chunkSize = scatterAttr.getChunkSize ().getInt ();
357360 // Verify if the first dimension of the tensor descriptor shape is
358361 // distributable.
359- assert (tdescShape[0 ] % ( laneLayout[0 ]) == 0 &&
362+ assert (tdescShape[0 ] == laneLayout[0 ] &&
360363 " tensor descriptor shape is not distributable" );
361- if (chunkSize > 1 )
362- return VectorType::get ({chunkSize / laneDataSize, laneDataSize},
363- getElementType ());
364- return VectorType::get ({laneDataSize}, getElementType ());
364+ return VectorType::get ({chunkSize}, getElementType ());
365365 }
366366
367367 // Case 2: block loads/stores
@@ -376,8 +376,7 @@ FailureOr<VectorType> TensorDescType::getDistributedVectorType() {
376376 // tensorSize must be adjusted for array_length.
377377 tensorSize *= getArrayLength ();
378378
379- return VectorType::get ({tensorSize / (sgSize * laneDataSize), laneDataSize},
380- getElementType ());
379+ return VectorType::get ({tensorSize / sgSize}, getElementType ());
381380}
382381
383382} // namespace xegpu
0 commit comments