
Commit b80c9d6

[MemoryOpToLLVM] Sync from upstream
Signed-off-by: Whitney Tsang <[email protected]>
1 parent: 0d118de

File tree

2 files changed (+29, -50 lines)


scripts/skiplist/lts/language.txt

Lines changed: 12 additions & 0 deletions
@@ -1573,3 +1573,15 @@ test/unit/language/test_core.py::test_scaled_dot[128-128-128-False-False-True-e4
 test/unit/language/test_core.py::test_trans_reshape
 test/unit/language/test_core.py::test_dot[1-64-128-128-4-False-True-none-tf32-float32-float32-1_0]
 test/unit/language/test_core.py::test_dot[1-64-128-128-4-False-True-none-tf32-float32-float32-1_1]
+test/unit/language/test_core.py::test_dot[1-64-128-128-4-True-True-none-tf32-float16-float16-1_0]
+test/unit/language/test_core.py::test_dot[1-64-128-128-4-True-True-none-tf32-float16-float16-1_1]
+test/unit/language/test_core.py::test_dot[1-64-128-128-4-True-True-none-tf32-float32-float32-1_0]
+test/unit/language/test_core.py::test_dot[1-64-128-128-4-True-True-none-tf32-float32-float32-1_1]
+test/unit/language/test_core.py::test_dot[1-64-128-128-4-True-False-none-tf32-float16-float16-1_0]
+test/unit/language/test_core.py::test_dot[1-64-128-128-4-True-False-none-tf32-float16-float16-1_1]
+test/unit/language/test_core.py::test_dot[1-64-128-128-4-True-False-none-tf32-float32-float32-1_0]
+test/unit/language/test_core.py::test_dot[1-64-128-128-4-True-False-none-tf32-float32-float32-1_1]
+test/unit/language/test_core.py::test_dot[1-128-128-64-4-True-True-none-tf32-float16-float16-1_0]
+test/unit/language/test_core.py::test_dot[1-128-128-64-4-True-True-none-tf32-float16-float16-1_1]
+test/unit/language/test_core.py::test_dot[1-128-128-64-4-True-False-none-tf32-float16-float16-1_0]
+test/unit/language/test_core.py::test_dot[1-128-128-64-4-True-False-none-tf32-float16-float16-1_1]

third_party/intel/lib/TritonIntelGPUToLLVM/MemoryOpToLLVM.cpp

Lines changed: 17 additions & 50 deletions
@@ -131,20 +131,27 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {
   LogicalResult
   matchAndRewrite(LocalLoadOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    MemDescType srcTy = op.getSrc().getType();
     RankedTensorType dstTy = op.getType();
-    Attribute srcLayout = srcTy.getEncoding();
     Attribute dstLayout = dstTy.getEncoding();
-    if (isa<SharedEncodingAttr>(srcLayout) &&
-        isa<BlockedEncodingAttr, MmaEncodingTrait, SliceEncodingAttr>(
-            dstLayout)) {
-      return lowerSharedToDistributed(op, adaptor, getTypeConverter(),
-                                      rewriter);
-    }
     if (isa<DotOperandEncodingAttr>(dstLayout)) {
-      return lowerSharedToDotOperand(op, adaptor, getTypeConverter(), rewriter);
+      auto dotLayout = cast<DotOperandEncodingAttr>(dstLayout);
+      if (auto dpasLayout =
+              dyn_cast_or_null<DpasEncodingAttr>(dotLayout.getParent())) {
+        auto sharedLayout =
+            cast<SharedEncodingAttr>(op.getSrc().getType().getEncoding());
+        int K;
+        if (dotLayout.getOpIdx() == 0) // $a
+          K = op.getType().getShape()[sharedLayout.getOrder()[0]];
+        else // $b
+          K = op.getType().getShape()[sharedLayout.getOrder()[1]];
+        bool isOuter = K == 1;
+        rewriter.replaceOp(op, lowerSharedToDotOperandDPAS(
+                                   op, adaptor, getTypeConverter(), rewriter,
+                                   dpasLayout, dotLayout, isOuter));
+        return success();
+      }
     }
-    return failure();
+    return lowerSharedToDistributed(op, adaptor, getTypeConverter(), rewriter);
   }
 
 private:
@@ -174,53 +181,13 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {
     return res;
   }
 
-  LogicalResult
-  lowerSharedToDotOperand(LocalLoadOp op, LocalLoadOpAdaptor adaptor,
-                          const LLVMTypeConverter *typeConverter,
-                          ConversionPatternRewriter &rewriter) const {
-    auto loc = op.getLoc();
-    RankedTensorType dstTy = op.getType();
-    Attribute dstLayout = dstTy.getEncoding();
-    auto dotLayout = cast<DotOperandEncodingAttr>(dstLayout);
-    auto sharedLayout =
-        cast<SharedEncodingAttr>(op.getSrc().getType().getEncoding());
-
-    int K;
-    if (dotLayout.getOpIdx() == 0) // $a
-      K = op.getType().getShape()[sharedLayout.getOrder()[0]];
-    else // $b
-      K = op.getType().getShape()[sharedLayout.getOrder()[1]];
-    bool isOuter = K == 1;
-
-    Value res;
-    if (auto dpasLayout =
-            dyn_cast_or_null<DpasEncodingAttr>(dotLayout.getParent())) {
-      res = lowerSharedToDotOperandDPAS(op, adaptor, typeConverter, rewriter,
-                                        dpasLayout, dotLayout, isOuter);
-    } else if (auto blockedLayout = dyn_cast_or_null<BlockedEncodingAttr>(
-                   dotLayout.getParent())) {
-      auto thread = getThreadId(rewriter, loc);
-      res = SharedToDotOperandFMA::convertLayout(
-          dotLayout.getOpIdx(), op.getSrc(), adaptor.getSrc(), blockedLayout,
-          thread, loc, getTypeConverter(), rewriter);
-    } else {
-      assert(false && "Unsupported dot operand layout found");
-    }
-
-    rewriter.replaceOp(op, res);
-    return success();
-  }
   LogicalResult
   lowerSharedToDistributed(LocalLoadOp op, LocalLoadOpAdaptor adaptor,
                            const LLVMTypeConverter *typeConverter,
                            ConversionPatternRewriter &rewriter) const {
     auto loc = op.getLoc();
     auto srcTy = op.getSrc().getType();
     auto dstTy = op.getResult().getType();
-    auto dstShape = dstTy.getShape();
-    auto srcSharedLayout = cast<SharedEncodingAttr>(srcTy.getEncoding());
-    assert(!isa<DotOperandEncodingAttr>(dstTy.getEncoding()) &&
-           "Unexpected rank of ConvertLayout(shared->blocked)");
 
     auto smemObj = LLVM::getSharedMemoryObjectFromStruct(
         loc, adaptor.getSrc(),
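
Note on the K/isOuter selection kept in the new inline DPAS branch: the K dimension is read from the destination shape through the shared layout's order, indexed by the dot-operand index. The standalone C++ sketch below illustrates only that indexing; the shapes, order, opIdx value, and the main() wrapper are hypothetical and not part of the patch.

// Illustration only: mimics how the DPAS branch above derives K and isOuter.
#include <array>
#include <cstdio>

int main() {
  // Hypothetical operand $a with destination shape [M, K] and a shared
  // layout order of {1, 0}; none of these values come from the commit.
  std::array<int64_t, 2> shape = {64, 32};
  std::array<unsigned, 2> order = {1, 0};
  int opIdx = 0; // 0 selects operand $a, 1 selects operand $b

  // Same selection as the diff: $a uses order[0], $b uses order[1].
  int64_t K = (opIdx == 0) ? shape[order[0]] : shape[order[1]];
  bool isOuter = (K == 1); // a unit K dimension marks an "outer" product

  std::printf("K = %lld, isOuter = %d\n", (long long)K, isOuter ? 1 : 0);
  return 0;
}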
