
Commit d06477e
move work from old branch
1 parent ae6b4b2

File tree: 3 files changed, +319 −2 lines

mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td

Lines changed: 1 addition & 1 deletion
@@ -409,7 +409,7 @@ def XeGPU_StoreNdOp : XeGPU_Op<"store_nd", [
 }
 
 def XeGPU_UpdateNdOffsetOp : XeGPU_Op<"update_nd_offset",
-                [AllTypesMatch<["TensorDesc", "result"]>]> {
+                [Pure, AllTypesMatch<["TensorDesc", "result"]>]> {
   let summary = "It updates the offsets for the TensorDesc.";
   let description = [{The op updates the offset of the given TensorDesc.
     The offsets are relative offset to the current position in the number
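The only functional change above is the added `Pure` trait. A minimal illustrative sketch (hypothetical values, not taken from this patch) of why it matters for the distribution patterns in the next file: once `update_nd_offset` is known to be side-effect free, the now-unused copy left behind inside the warp region can be removed by dead-code elimination.

    // Illustrative only: %dead has no uses after the yield is rewritten, so a
    // Pure update_nd_offset can simply be dropped during cleanup.
    %dead = xegpu.update_nd_offset %arg0, [%c32, %c16]
        : !xegpu.tensor_desc<4x8xf32, #lo0>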

mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp

Lines changed: 203 additions & 1 deletion
@@ -301,6 +301,10 @@ class LayoutInfoPropagation
                        ArrayRef<LayoutInfoLattice *> operands,
                        ArrayRef<const LayoutInfoLattice *> results);
 
+  void visitPrefetchNdOp(xegpu::PrefetchNdOp prefetch,
+                         ArrayRef<LayoutInfoLattice *> operands,
+                         ArrayRef<const LayoutInfoLattice *> results);
+
   void visitVectorMultiReductionOp(vector::MultiDimReductionOp reduction,
                                    ArrayRef<LayoutInfoLattice *> operands,
                                    ArrayRef<const LayoutInfoLattice *> results);
@@ -352,6 +356,9 @@ LogicalResult LayoutInfoPropagation::visitOperation(
       .Case<xegpu::UpdateNdOffsetOp>([&](auto updateNdOffsetOp) {
        visitUpdateNdOffsetOp(updateNdOffsetOp, operands, results);
      })
+      .Case<xegpu::PrefetchNdOp>([&](auto prefetchNdOp) {
+        visitPrefetchNdOp(prefetchNdOp, operands, results);
+      })
      // No need to propagate the layout to operands in CreateNdDescOp because
      // they are scalars (offsets, sizes, etc.).
      .Case<xegpu::CreateNdDescOp>([&](auto createNdDescOp) {})
@@ -381,6 +388,18 @@ LogicalResult LayoutInfoPropagation::visitOperation(
   return success();
 }
 
+void LayoutInfoPropagation::visitPrefetchNdOp(
+    xegpu::PrefetchNdOp prefetch, ArrayRef<LayoutInfoLattice *> operands,
+    ArrayRef<const LayoutInfoLattice *> results) {
+  // Here we assign the default layout to the tensor descriptor operand of
+  // prefetch.
+  auto tdescTy = prefetch.getTensorDescType();
+  auto prefetchLayout = getDefaultLayoutInfo(
+      VectorType::get(tdescTy.getShape(), tdescTy.getElementType()));
+  // Propagate the layout to the source tensor descriptor.
+  propagateIfChanged(operands[0], operands[0]->meet(prefetchLayout));
+}
+
 void LayoutInfoPropagation::visitVectorMultiReductionOp(
     vector::MultiDimReductionOp reduction,
     ArrayRef<LayoutInfoLattice *> operands,
@@ -1412,6 +1431,174 @@ struct DpasDistribution final : public gpu::WarpDistributionPattern {
   }
 };
 
+/// Sink an update_nd_offset op feeding into the yield op of an enclosing
+/// `gpu.warp_execute_on_lane_0` region. The warp op will still contain the
+/// original op, which is no longer used by the yield op (and should be cleaned
+/// up later). The yield op bypasses the updateOp and carries its operands
+/// instead. The tensor descriptor type is not distributed. Appropriate cast
+/// ops are inserted if the distributed types do not match the expected xegpu
+/// SIMT types.
+/// Example:
+/// ```
+///   #lo0 = #xegpu.layout<wi_layout = [1, 8], wi_data = [1, 1]>
+///   %r = gpu.warp_execute_on_lane_0(%laneid) ->
+///                   (!xegpu.tensor_desc<4x8xf32, #lo0>) {
+///     ...
+///     %update = xegpu.update_nd_offset %arg0, [%c32, %c16]:
+///       !xegpu.tensor_desc<4x8xf32, #lo0>
+///     gpu.yield %update
+///   }
+///   ...
+/// ```
+/// To
+/// ```
+///   %r:2 = gpu.warp_execute_on_lane_0(%laneid) -> (vector<4x1xf32>,
+///   !xegpu.tensor_desc<4x8xf32, #lo0>) {
+///     ...
+///     %dead = xegpu.update_nd_offset %arg0, [%c32, %c16]:
+///       !xegpu.tensor_desc<4x8xf32, #lo0>
+///     gpu.yield %dead, %arg0, %c32, %c16
+///   }
+///   %0 = xegpu.unrealized_conversion_cast %r#1: !xegpu.tensor_desc<4x8xf32,
+///        #lo0> -> !xegpu.tensor_desc<4x8xf32>
+///   %1 = xegpu.update_nd_offset %0, [%c32, %c16]:
+///     !xegpu.tensor_desc<4x8xf32>
+///   ...
+/// ```
+struct UpdateNdOffsetDistribution final : public gpu::WarpDistributionPattern {
+  using gpu::WarpDistributionPattern::WarpDistributionPattern;
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
+                                PatternRewriter &rewriter) const override {
+    OpOperand *operand =
+        getWarpResult(subgroupOp, llvm::IsaPred<xegpu::UpdateNdOffsetOp>);
+    if (!operand)
+      return rewriter.notifyMatchFailure(
+          subgroupOp, "warp result is not a xegpu::UpdateNdOffset op");
+    auto updateOp = operand->get().getDefiningOp<xegpu::UpdateNdOffsetOp>();
+    unsigned operandIdx = operand->getOperandNumber();
+    auto newTensorDescTy = dropLayouts(updateOp.getTensorDescType());
+
+    SmallVector<Value, 3> newYieldValues;
+    SmallVector<Type, 3> newYieldTypes;
+    for (auto operand : updateOp->getOperands()) {
+      newYieldValues.push_back(operand);
+      if (isa<xegpu::TensorDescType>(operand.getType())) {
+        newYieldTypes.push_back(newTensorDescTy);
+      } else {
+        newYieldTypes.push_back(operand.getType());
+      }
+    }
+    SmallVector<size_t> newRetIndices;
+    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+        rewriter, subgroupOp, newYieldValues, newYieldTypes, newRetIndices);
+    rewriter.setInsertionPointAfter(newWarpOp);
+    SmallVector<Value> newUpdateOperands;
+    for (auto i : newRetIndices) {
+      if (isa<xegpu::TensorDescType>(newWarpOp.getResult(i).getType())) {
+        newUpdateOperands.push_back(resolveDistributedTy(
+            newWarpOp.getResult(i), newTensorDescTy, rewriter));
+      } else {
+        newUpdateOperands.push_back(newWarpOp.getResult(i));
+      }
+    }
+    auto newUpdateOp = rewriter.create<xegpu::UpdateNdOffsetOp>(
+        newWarpOp.getLoc(), newTensorDescTy, newUpdateOperands,
+        removeTemporaryLayoutAttributes(updateOp->getAttrs()));
+    Value distributedVal = newWarpOp.getResult(operandIdx);
+    rewriter.replaceAllUsesWith(distributedVal, newUpdateOp);
+    return success();
+  }
+};
+
+struct PrefetchNdDistribution final : public gpu::WarpDistributionPattern {
+  using gpu::WarpDistributionPattern::WarpDistributionPattern;
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
+                                PatternRewriter &rewriter) const override {
+    auto yield = cast<gpu::YieldOp>(
+        subgroupOp.getBodyRegion().getBlocks().begin()->getTerminator());
+    Operation *lastNode = yield->getPrevNode();
+    auto prefetchOp = dyn_cast_or_null<xegpu::PrefetchNdOp>(lastNode);
+    if (!prefetchOp)
+      return failure();
+    auto layout = prefetchOp.getTensorDescType().getLayoutAttr();
+    if (!layout)
+      return rewriter.notifyMatchFailure(
+          prefetchOp, "the source tensor descriptor lacks layout attribute");
+
+    SmallVector<Value, 1> newYieldValues = {prefetchOp.getTensorDesc()};
+    SmallVector<Type, 1> newYieldTypes = {prefetchOp.getTensorDescType()};
+    SmallVector<size_t> newRetIndices;
+    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+        rewriter, subgroupOp, newYieldValues, newYieldTypes, newRetIndices);
+
+    auto newTensorDescTy = dropLayouts(prefetchOp.getTensorDescType());
+    rewriter.setInsertionPointAfter(newWarpOp);
+    SmallVector<Value> newPrefetchOperands = {resolveDistributedTy(
+        newWarpOp.getResult(newRetIndices[0]), newTensorDescTy, rewriter)};
+    rewriter.create<xegpu::PrefetchNdOp>(
+        newWarpOp.getLoc(), TypeRange{}, newPrefetchOperands,
+        removeTemporaryLayoutAttributes(prefetchOp->getAttrs()));
+    rewriter.eraseOp(prefetchOp);
+    return success();
+  }
+};
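`PrefetchNdDistribution` has no doc comment yet. A hedged before/after sketch in the spirit of the comment above (the names, the 4x8 shape, and the exact cast op are illustrative assumptions, not taken from this patch): the prefetch's tensor descriptor is yielded out of the warp region, its layout is dropped, and the prefetch is re-created outside.

    #lo0 = #xegpu.layout<wi_layout = [1, 8], wi_data = [1, 1]>
    gpu.warp_execute_on_lane_0(%laneid) -> () {
      ...
      xegpu.prefetch_nd %arg0 : !xegpu.tensor_desc<4x8xf16, #lo0>
    }

would become, roughly:

    %r = gpu.warp_execute_on_lane_0(%laneid) -> (!xegpu.tensor_desc<4x8xf16, #lo0>) {
      ...
      gpu.yield %arg0 : !xegpu.tensor_desc<4x8xf16, #lo0>
    }
    // cast inserted by resolveDistributedTy; the exact form may differ
    %0 = unrealized_conversion_cast %r : !xegpu.tensor_desc<4x8xf16, #lo0>
        -> !xegpu.tensor_desc<4x8xf16>
    xegpu.prefetch_nd %0 : !xegpu.tensor_desc<4x8xf16>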
+
+/// Generic pattern for sinking a GPU index operation feeding into the yield op
+/// of an enclosing `gpu.warp_execute_on_lane_0` region. The original index op
+/// becomes dead and an equivalent copy of the index op is created outside the
+/// warp op.
+/// Example:
+/// ```
+///   %r = gpu.warp_execute_on_lane_0(%laneid) -> (index) {
+///     ...
+///     %index = gpu.block_id x : index
+///     gpu.yield %index
+///   }
+///   ...
+/// ```
+/// To
+/// ```
+///   %r = gpu.warp_execute_on_lane_0(%laneid) -> (index) {
+///     ...
+///     %dead = gpu.block_id x : index
+///     gpu.yield %dead
+///   }
+///   %0 = gpu.block_id x : index
+///   ...
+/// ```
+template <typename IndexOp>
+struct GpuIndexOpDistribution final : public gpu::WarpDistributionPattern {
+  using gpu::WarpDistributionPattern::WarpDistributionPattern;
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
+                                PatternRewriter &rewriter) const override {
+    auto operand = getWarpResult(subgroupOp, llvm::IsaPred<IndexOp>);
+    if (!operand)
+      return rewriter.notifyMatchFailure(subgroupOp,
+                                         "warp result is not a gpu index op");
+    auto indexOp = operand->get().getDefiningOp<IndexOp>();
+    unsigned operandIdx = operand->getOperandNumber();
+    SmallVector<Value, 3> newYieldValues;
+    SmallVector<Type, 3> newYieldTypes;
+    for (auto operand : indexOp->getOperands()) {
+      newYieldValues.push_back(operand);
+      newYieldTypes.push_back(operand.getType());
+    }
+    SmallVector<size_t> newRetIndices;
+    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+        rewriter, subgroupOp, newYieldValues, newYieldTypes, newRetIndices);
+    rewriter.setInsertionPointAfter(newWarpOp);
+    SmallVector<Value> newIndexOperands;
+    for (auto i : newRetIndices) {
+      newIndexOperands.push_back(newWarpOp.getResult(i));
+    }
+    auto newIndexOp = rewriter.create<IndexOp>(
+        newWarpOp.getLoc(), newIndexOperands,
+        removeTemporaryLayoutAttributes(indexOp->getAttrs()));
+    Value distributedVal = newWarpOp.getResult(operandIdx);
+    rewriter.replaceAllUsesWith(distributedVal, newIndexOp);
+    return success();
+  }
+};
+
 } // namespace
 
 namespace {
@@ -1430,7 +1617,22 @@ struct XeGPUSubgroupDistributePass final
 void xegpu::populateXeGPUSubgroupDistributePatterns(
     RewritePatternSet &patterns) {
   patterns.add<CreateNdDescDistribution, StoreNdDistribution,
-               LoadNdDistribution, DpasDistribution>(patterns.getContext());
+               LoadNdDistribution, DpasDistribution, PrefetchNdDistribution,
+               UpdateNdOffsetDistribution>(patterns.getContext());
+  // TODO: Is this the right place to add these patterns?
+  patterns.add<GpuIndexOpDistribution<gpu::BlockIdOp>,
+               GpuIndexOpDistribution<gpu::BlockDimOp>,
+               GpuIndexOpDistribution<gpu::SubgroupIdOp>,
+               GpuIndexOpDistribution<gpu::SubgroupSizeOp>,
+               GpuIndexOpDistribution<gpu::NumSubgroupsOp>,
+               GpuIndexOpDistribution<gpu::ClusterDimOp>,
+               GpuIndexOpDistribution<gpu::ClusterDimBlocksOp>,
+               GpuIndexOpDistribution<gpu::ClusterIdOp>,
+               GpuIndexOpDistribution<gpu::ClusterBlockIdOp>,
+               GpuIndexOpDistribution<gpu::GridDimOp>,
+               GpuIndexOpDistribution<gpu::ThreadIdOp>,
+               GpuIndexOpDistribution<gpu::LaneIdOp>,
+               GpuIndexOpDistribution<gpu::GlobalIdOp>>(patterns.getContext());
 }
 
 void XeGPUSubgroupDistributePass::runOnOperation() {

mlir/test/Dialect/XeGPU/subgroup-distribution.mlir

Lines changed: 115 additions & 0 deletions
@@ -160,3 +160,118 @@ gpu.func @create_nd_tdesc_non_memref(%arg0: ui64, %arg1: ui64,
   gpu.return
 }
 }
+
+// -----
+// CHECK-LABEL: gpu.func @test_update_nd_offset_1d(
+// CHECK: %[[ARG0:[0-9a-zA-Z]+]]: memref<256xf32>) {
+// CHECK: %[[CST:.*]] = arith.constant dense<1.000000e+00> : vector<1xf32>
+// CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG0]][%{{.*}}] : memref<256xf32> -> !xegpu.tensor_desc<16xf32>
+// CHECK: %[[T1:.*]] = xegpu.update_nd_offset %[[T0]], [%c32] : !xegpu.tensor_desc<16xf32>
+// CHECK: xegpu.store_nd %[[CST]], %[[T1]] : vector<1xf32>, !xegpu.tensor_desc<16xf32>
+gpu.module @test {
+gpu.func @test_update_nd_offset_1d(%arg0: memref<256xf32>){
+  %c0 = arith.constant 0 : index
+  %c32 = arith.constant 32 : index
+  %1 = arith.constant dense<1.000000e+00> : vector<16xf32>
+  %0 = xegpu.create_nd_tdesc %arg0[%c0] : memref<256xf32> -> !xegpu.tensor_desc<16xf32>
+  %2 = xegpu.update_nd_offset %0, [%c32] : !xegpu.tensor_desc<16xf32>
+  xegpu.store_nd %1, %2 : vector<16xf32>, !xegpu.tensor_desc<16xf32>
+  gpu.return
+}
+}
+
+// -----
+// CHECK-LABEL: gpu.func @test_update_nd_offset_2d
+// CHECK: %[[ARG0:[0-9a-zA-Z]+]]: memref<256x256xf32>) {
+// CHECK: %[[CST:.*]] = arith.constant dense<1.000000e+00> : vector<16xf32>
+// CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG0]][%{{.*}}] : memref<256x256xf32> -> !xegpu.tensor_desc<16x16xf32>
+// CHECK: %[[T1:.*]] = xegpu.update_nd_offset %[[T0]], [%c32, %c32] : !xegpu.tensor_desc<16x16xf32>
+// CHECK: xegpu.store_nd %[[CST]], %[[T1]] : vector<16xf32>, !xegpu.tensor_desc<16x16xf32>
+gpu.module @test {
+gpu.func @test_update_nd_offset_2d(%arg0: memref<256x256xf32>){
+  %c0 = arith.constant 0 : index
+  %c32 = arith.constant 32 : index
+  %1 = arith.constant dense<1.000000e+00> : vector<16x16xf32>
+  %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<256x256xf32> -> !xegpu.tensor_desc<16x16xf32>
+  %2 = xegpu.update_nd_offset %0, [%c32, %c32] : !xegpu.tensor_desc<16x16xf32>
+  xegpu.store_nd %1, %2 : vector<16x16xf32>, !xegpu.tensor_desc<16x16xf32>
+  gpu.return
+}
+}
+
+// -----
+// CHECK-LABEL: gpu.func @test_prefetch_2d
+// CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: memref<256x256xf16>) {
+// CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG0]][%{{.*}}] : memref<256x256xf16> -> !xegpu.tensor_desc<16x16xf16>
+// CHECK: xegpu.prefetch_nd %[[T0]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<16x16xf16>
+gpu.module @test {
+gpu.func @test_prefetch_2d(%arg0: memref<256x256xf16>){
+  %c0 = arith.constant 0 : index
+  %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<256x256xf16> -> !xegpu.tensor_desc<16x16xf16>
+  xegpu.prefetch_nd %0 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>: !xegpu.tensor_desc<16x16xf16>
+  gpu.return
+}
+}
+
+// -----
+// CHECK-LABEL: gpu.func @test_prefetch_1d
+// CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: memref<256xf16>) {
+// CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG0]][%{{.*}}] : memref<256xf16> -> !xegpu.tensor_desc<16xf16>
+// CHECK: xegpu.prefetch_nd %[[T0]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<16xf16>
+gpu.module @test {
+gpu.func @test_prefetch_1d(%arg0: memref<256xf16>){
+  %c0 = arith.constant 0 : index
+  %0 = xegpu.create_nd_tdesc %arg0[%c0] : memref<256xf16> -> !xegpu.tensor_desc<16xf16>
+  xegpu.prefetch_nd %0 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>: !xegpu.tensor_desc<16xf16>
+  gpu.return
+}
+}
+
+// -----
+// CHECK-LABEL: gpu.func @test_gemm_loop
+// CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: memref<1024x1024xbf16>, %[[ARG1:[0-9a-zA-Z]+]]: memref<1024x1024xbf16>, %[[ARG2:[0-9a-zA-Z]+]]: memref<1024x1024xf32>) {
+// CHECK: %[[BLOCK_ID_Y:.*]] = gpu.block_id y
+// CHECK: %[[Y_COORD:.*]] = arith.muli %[[BLOCK_ID_Y]], %c16 : index
+// CHECK: %[[BLOCK_ID_X:.*]] = gpu.block_id x
+// CHECK: %[[X_COORD:.*]] = arith.muli %[[BLOCK_ID_X]], %c8 : index
+// CHECK: %[[T2:.*]] = xegpu.create_nd_tdesc %[[ARG2]][%[[X_COORD]], %[[Y_COORD]]] : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32>
+// CHECK: %[[T3:.*]] = xegpu.load_nd %[[T2]] : !xegpu.tensor_desc<8x16xf32> -> vector<8xf32>
+// CHECK: %[[T4:.*]] = vector.shape_cast %[[T3]] : vector<8xf32> to vector<8x1xf32>
+// CHECK: %[[T5:.*]] = scf.for %[[K:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ARG4:.*]] = %[[T4]]) -> (vector<8x1xf32>) {
+// CHECK: %[[T10:.*]] = xegpu.create_nd_tdesc %[[ARG1]][%[[K]], %[[Y_COORD]]] : memref<1024x1024xbf16> -> !xegpu.tensor_desc<16x16xbf16>
+// CHECK: %[[T11:.*]] = xegpu.load_nd %[[T10]] <{packed}> : !xegpu.tensor_desc<16x16xbf16> -> vector<16xbf16>
+// CHECK: %[[T12:.*]] = xegpu.create_nd_tdesc %[[ARG0]][%[[X_COORD]], %[[K]]] : memref<1024x1024xbf16> -> !xegpu.tensor_desc<8x16xbf16>
+// CHECK: %[[T13:.*]] = xegpu.load_nd %[[T12]] : !xegpu.tensor_desc<8x16xbf16> -> vector<8xbf16>
+// CHECK: %[[T14:.*]] = vector.shape_cast %[[ARG4]] : vector<8x1xf32> to vector<8xf32>
+// CHECK: %[[T15:.*]] = xegpu.dpas %[[T13]], %[[T11]], %[[T14]] : vector<8xbf16>, vector<16xbf16>, vector<8xf32> -> vector<8xf32>
+// CHECK: %[[T16:.*]] = vector.shape_cast %[[T15]] : vector<8xf32> to vector<8x1xf32>
+// CHECK: scf.yield %[[T16]] : vector<8x1xf32>
+// CHECK: }
+// CHECK: %[[T8:.*]] = xegpu.create_nd_tdesc %[[ARG2]]{{.*}} : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32>
+// CHECK: %[[T9:.*]] = vector.shape_cast %[[T5]] : vector<8x1xf32> to vector<8xf32>
+// CHECK: xegpu.store_nd %[[T9]], %[[T8]] : vector<8xf32>, !xegpu.tensor_desc<8x16xf32>
+gpu.module @test {
+gpu.func @test_gemm_loop(%arg0: memref<1024x1024xbf16>, %arg1: memref<1024x1024xbf16>, %arg2: memref<1024x1024xf32>){
+  %c0 = arith.constant 0 : index
+  %c16 = arith.constant 16 : index
+  %c8 = arith.constant 8 : index
+  %c1024 = arith.constant 1024 : index
+  %0 = gpu.block_id x
+  %1 = gpu.block_id y
+  %2 = arith.muli %0, %c8 : index
+  %3 = arith.muli %1, %c16 : index
+  %4 = xegpu.create_nd_tdesc %arg2[%2, %3] : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32>
+  %5 = xegpu.load_nd %4 : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
+  %6 = scf.for %arg3 = %c0 to %c1024 step %c16 iter_args(%arg4 = %5) -> (vector<8x16xf32>) {
+    %7 = xegpu.create_nd_tdesc %arg0[%2, %arg3] : memref<1024x1024xbf16> -> !xegpu.tensor_desc<8x16xbf16>
+    %8 = xegpu.create_nd_tdesc %arg1[%arg3, %3] : memref<1024x1024xbf16> -> !xegpu.tensor_desc<16x16xbf16>
+    %9 = xegpu.load_nd %7 : !xegpu.tensor_desc<8x16xbf16> -> vector<8x16xbf16>
+    %10 = xegpu.load_nd %8 : !xegpu.tensor_desc<16x16xbf16> -> vector<16x16xbf16>
+    %11 = xegpu.dpas %9, %10, %arg4 : vector<8x16xbf16>, vector<16x16xbf16>, vector<8x16xf32> -> vector<8x16xf32>
+    scf.yield %11 : vector<8x16xf32>
+  }
+  xegpu.store_nd %6, %4 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
+  gpu.return
+}
+}
