@@ -312,87 +312,6 @@ LogicalResult TensorDescType::verify(
   return success();
 }
 
-// If tensor descriptor has a layout attribute it is used in SIMT mode.
-// In this mode, the distributed vector shape is determined as follows:
-// Definitions:
-//        lane_data_size = lane_data[0] × lane_data[1]
-//        subgroup_size = lane_layout[0] × lane_layout[1]
-//        distribution_unit_size = subgroup_size × lane_data_size
-// ---------------------------------------------------------------------
-// Case 1: Regular loads/stores.
-// ---------------------------------------------------------------------
-// The following conditions must be met:
-//        * tensor_desc[0] == lane_layout[0]
-// Distributed vector is a 1D vector with shape:
-//        [chunk_size]
-// ---------------------------------------------------------------------
-// Case 2: Block loads/stores.
-// ---------------------------------------------------------------------
-// Additional definitions:
-//        tensor_size = tensor_desc[0] * .. * tensor_desc[r-1] * array_length
-//        n_distribution_units = tensor_size / distribution_unit_size
-//        fragment_size = n_distribution_units * lane_data_size
-// Given above definitions, the following conditions must be met:
-//        * tensor_desc[0] % (lane_layout[0] × lane_data[0]) == 0
-//        * tensor_desc[1] % (lane_layout[1] × lane_data[1]) == 0
-// Distributed vector is a 1D vector with shape:
-//        [fragment_size]
-FailureOr<VectorType> getDistributedVectorType(xegpu::TensorDescType tdescTy) {
-  auto layout = llvm::dyn_cast_if_present<LayoutAttr>(tdescTy.getLayout());
-  // This only works for a subgroup-level layout, which carries only
-  // lane_layout and lane_data, and is used to distribute SIMD code into
-  // SIMT code.
-  if (!layout || !layout.isSgLayout())
-    return failure();
-
-  SmallVector<int64_t> laneData(layout.getLaneData().asArrayRef());
-  SmallVector<int64_t> laneLayout(layout.getLaneLayout().asArrayRef());
-  auto tdescShape = tdescTy.getShape();
-  auto elementType = tdescTy.getElementType();
-
-  // Compute sgSize by multiplying the elements of laneLayout,
-  // e.g. for a 2D layout, sgSize = laneLayout[0] * laneLayout[1];
-  // for a 1D layout, sgSize = laneLayout[0].
-  auto sgSize = std::accumulate(laneLayout.begin(), laneLayout.end(), 1,
-                                std::multiplies<int64_t>());
-  // Case 1: regular loads/stores.
-  auto scatterAttr = tdescTy.getEncodingAsScatterTensorDescAttr();
-  if (scatterAttr) {
-    auto chunkSize = scatterAttr.getChunkSize().getInt();
-    // Verify that the first dimension of the tensor descriptor shape is
-    // distributable.
-    assert(tdescShape[0] == laneLayout[0] &&
-           "tensor descriptor shape is not distributable");
-    return VectorType::get({chunkSize}, elementType);
-  }
-
-  // Case 2: block loads/stores.
-  // Check if the tensor descriptor shape is distributable.
-  int64_t tensorSize = 1;
-  for (auto [tdescDim, laneDim, laneDataDim] :
-       llvm::zip_equal(tdescShape, laneLayout, laneData)) {
-    assert((tdescDim % (laneDim * laneDataDim) == 0) &&
-           "tensor descriptor shape is not distributable");
-    tensorSize *= tdescDim;
-  }
-  // tensorSize must be adjusted for array_length.
-  tensorSize *= tdescTy.getArrayLength();
-
-  return VectorType::get({tensorSize / sgSize}, elementType);
-}
-
-// Helper to get the distributed vector type for a given vector type according
-// to a given LayoutAttr.
-FailureOr<VectorType> getDistributedVectorType(VectorType originalType,
-                                               LayoutAttr layout) {
-  auto shape = originalType.getShape();
-  auto helperTdescTy = xegpu::TensorDescType::get(
-      shape, originalType.getElementType(),
-      /*array_length=*/1, /*boundary_check=*/true,
-      /*memory_space=*/xegpu::MemorySpace::Global, layout);
-  return xegpu::getDistributedVectorType(helperTdescTy);
-}
-
 } // namespace xegpu
 } // namespace mlir
 
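The Case 2 arithmetic in the removed comments can be checked with a small standalone sketch. The 32x16 descriptor shape, lane_layout of [1, 16], and lane_data of [1, 1] below are hypothetical example values chosen only to illustrate the formulas, not values taken from the patch:

#include <cassert>
#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

// Walks through the Case 2 (block load/store) formulas with example values:
// a 32x16 tensor descriptor, lane_layout [1, 16], lane_data [1, 1].
int main() {
  std::vector<int64_t> tdescShape = {32, 16};
  std::vector<int64_t> laneLayout = {1, 16};
  std::vector<int64_t> laneData = {1, 1};
  int64_t arrayLength = 1;

  // subgroup_size = lane_layout[0] × lane_layout[1] = 16
  int64_t sgSize = std::accumulate(laneLayout.begin(), laneLayout.end(),
                                   int64_t{1}, std::multiplies<int64_t>());

  // tensor_size = tensor_desc[0] * tensor_desc[1] * array_length = 512,
  // checking each dimension for distributability along the way.
  int64_t tensorSize = 1;
  for (size_t i = 0; i < tdescShape.size(); ++i) {
    assert(tdescShape[i] % (laneLayout[i] * laneData[i]) == 0);
    tensorSize *= tdescShape[i];
  }
  tensorSize *= arrayLength;

  // fragment_size = n_distribution_units * lane_data_size, which reduces to
  // tensor_size / subgroup_size = 512 / 16 = 32, i.e. each lane receives a
  // 1D vector of 32 elements (vector<32xf16> for an f16 descriptor).
  assert(tensorSize / sgSize == 32);
  return 0;
}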
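The second removed overload merely wraps the vector shape in a throwaway TensorDescType so that both entry points share the same distribution logic. A hypothetical call site, assuming an OpBuilder named builder and a subgroup-level LayoutAttr named layout (lane_layout [1, 16], lane_data [1, 1]) are already in scope:

// Hypothetical usage, not part of the patch; `builder` and `layout` are
// assumed to exist in the surrounding code.
VectorType srcTy = VectorType::get({32, 16}, builder.getF16Type());
FailureOr<VectorType> distTy = xegpu::getDistributedVectorType(srcTy, layout);
if (succeeded(distTy)) {
  // With the layout above, distTy->getShape() is {32}, i.e. vector<32xf16>.
}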