Skip to content

Commit 71cc33b

Browse files
Merge OpenAI Triton commit 19eef7c (#4801)
This PR change the Triton base from 815b2a4 to 19eef7c (Jul 18). Pass rate: 98.62%
2 parents bae3356 + fb0b569 commit 71cc33b

File tree

23 files changed

+860
-227
lines changed

23 files changed

+860
-227
lines changed

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 0 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -2486,39 +2486,6 @@ struct TritonGPUInferLayoutInterface
24862486
Attribute srcEnc,
24872487
ArrayRef<int64_t> dstShape,
24882488
Attribute &dstEnc) const {
2489-
if (auto mmaEncoding = dyn_cast<NVMMASharedEncodingAttr>(srcEnc)) {
2490-
// TODO: supporting reshape of CTA layouts is non-trivial.
2491-
if (getNumCTAs(mmaEncoding) > 1)
2492-
return failure();
2493-
int innerDimDst =
2494-
mmaEncoding.getTransposed() ? dstShape.front() : dstShape.back();
2495-
int innerDimSrc =
2496-
mmaEncoding.getTransposed() ? srcShape.front() : srcShape.back();
2497-
// For now disallow reshape of the inner dimension.
2498-
if (innerDimDst != innerDimSrc)
2499-
return failure();
2500-
auto *ctx = srcEnc.getContext();
2501-
2502-
// CTALayout can be all 1's because we bailed on multi-CTA layouts above.
2503-
auto CTALayout = CTALayoutAttr::get(
2504-
ctx,
2505-
/*CTAsPerCGA=*/SmallVector<unsigned>(dstShape.size(), 1),
2506-
/*CTASplitNum=*/SmallVector<unsigned>(dstShape.size(), 1),
2507-
/*CTAOrder=*/llvm::to_vector(llvm::seq<unsigned>(dstShape.size())));
2508-
dstEnc = NVMMASharedEncodingAttr::get(
2509-
ctx, mmaEncoding.getSwizzlingByteWidth(), mmaEncoding.getTransposed(),
2510-
mmaEncoding.getElementBitWidth(), mmaEncoding.getFp4Padded(),
2511-
CTALayout);
2512-
// Big guns, check linear layouts are equivalent
2513-
// We disallow reshaping memdesc_subviews in the verifier
2514-
// We disallow reshaping memdesc_subviews in the verifier
2515-
auto srcLL = toLinearLayout(srcShape, srcEnc, srcShape);
2516-
auto dstLL = toLinearLayout(dstShape, dstEnc, dstShape);
2517-
if (reshapeLayout(ctx, srcLL, dstShape) != dstLL) {
2518-
return failure();
2519-
}
2520-
return success();
2521-
}
25222489
auto src = mlir::dyn_cast<BlockedEncodingAttr>(srcEnc);
25232490
if (!src) {
25242491
return failure();

lib/Dialect/TritonGPU/IR/Ops.cpp

Lines changed: 42 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -496,6 +496,46 @@ LogicalResult MemDescReshapeOp::verify() {
496496
return success();
497497
}
498498

499+
static LogicalResult inferMemDescReshapeOpEncoding(ArrayRef<int64_t> srcShape,
500+
Attribute srcEnc,
501+
ArrayRef<int64_t> dstShape,
502+
Attribute &dstEnc) {
503+
if (auto mmaEncoding = dyn_cast<NVMMASharedEncodingAttr>(srcEnc)) {
504+
// TODO: supporting reshape of CTA layouts is non-trivial.
505+
if (getNumCTAs(mmaEncoding) > 1)
506+
return failure();
507+
int innerDimDst =
508+
mmaEncoding.getTransposed() ? dstShape.front() : dstShape.back();
509+
int innerDimSrc =
510+
mmaEncoding.getTransposed() ? srcShape.front() : srcShape.back();
511+
// For now disallow reshape of the inner dimension.
512+
if (innerDimDst != innerDimSrc)
513+
return failure();
514+
auto *ctx = srcEnc.getContext();
515+
516+
// CTALayout can be all 1's because we bailed on multi-CTA layouts above.
517+
auto CTALayout = CTALayoutAttr::get(
518+
ctx,
519+
/*CTAsPerCGA=*/SmallVector<unsigned>(dstShape.size(), 1),
520+
/*CTASplitNum=*/SmallVector<unsigned>(dstShape.size(), 1),
521+
/*CTAOrder=*/llvm::to_vector(llvm::seq<unsigned>(dstShape.size())));
522+
dstEnc = NVMMASharedEncodingAttr::get(
523+
ctx, mmaEncoding.getSwizzlingByteWidth(), mmaEncoding.getTransposed(),
524+
mmaEncoding.getElementBitWidth(), mmaEncoding.getFp4Padded(),
525+
CTALayout);
526+
// Big guns, check linear layouts are equivalent
527+
// We disallow reshaping memdesc_subviews in the verifier
528+
// We disallow reshaping memdesc_subviews in the verifier
529+
auto srcLL = toLinearLayout(srcShape, srcEnc, srcShape);
530+
auto dstLL = toLinearLayout(dstShape, dstEnc, dstShape);
531+
if (reshapeLayout(ctx, srcLL, dstShape) != dstLL) {
532+
return failure();
533+
}
534+
return success();
535+
}
536+
return failure();
537+
}
538+
499539
LogicalResult MemDescReshapeOp::inferReturnTypes(
500540
MLIRContext *context, std::optional<Location> loc, MemDescType srcTy,
501541
ArrayRef<int64_t> dstShape, MemDescType &inferredReturnType) {
@@ -505,9 +545,8 @@ LogicalResult MemDescReshapeOp::inferReturnTypes(
505545

506546
Attribute dstEncoding;
507547
if (Attribute srcEnc = srcTy.getEncoding()) {
508-
auto *inferLayout = cast<DialectInferLayoutInterface>(&srcEnc.getDialect());
509-
if (failed(inferLayout->inferReshapeOpEncoding(srcTy.getShape(), srcEnc,
510-
dstShape, dstEncoding, loc)))
548+
if (failed(inferMemDescReshapeOpEncoding(srcTy.getShape(), srcEnc, dstShape,
549+
dstEncoding)))
511550
return failure();
512551
}
513552

python/src/ir.cc

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,7 @@ ReproducerStreamFactory makeConsoleReproducer() {
152152
OpPrintingFlags getOpPrintingFlags() {
153153
auto printingFlags = OpPrintingFlags();
154154
printingFlags.enableDebugInfo();
155+
printingFlags.printNameLocAsPrefix(true);
155156
return printingFlags;
156157
}
157158

@@ -372,11 +373,15 @@ void init_triton_ir(py::module &&m) {
372373
self.replaceAllUsesWith(newValue);
373374
})
374375
.def("get_type", &Value::getType)
375-
.def("id", [](Value &self) {
376-
// The Value is identified by and compared with
377-
// other Values via the underlying ValueImpl
378-
return (uint64_t)self.getImpl();
379-
});
376+
.def("id",
377+
[](Value &self) {
378+
// The Value is identified by and compared with
379+
// other Values via the underlying ValueImpl
380+
return (uint64_t)self.getImpl();
381+
})
382+
.def("set_loc",
383+
[](Value &self, Location loc) { return self.setLoc(loc); })
384+
.def("get_loc", [](Value &self) { return self.getLoc(); });
380385

381386
py::class_<OpResult, Value>(m, "op_result", py::module_local());
382387

@@ -929,6 +934,28 @@ void init_triton_ir(py::module &&m) {
929934
// locs
930935
.def("set_loc",
931936
[](TritonOpBuilder &self, Location loc) { self.setLastLoc(loc); })
937+
.def("set_loc",
938+
[](TritonOpBuilder &self, std::string name) {
939+
auto nameAttr = StringAttr::get(self.getContext(), name);
940+
auto loc = NameLoc::get(nameAttr);
941+
self.setLastLoc(loc);
942+
})
943+
.def("create_loc",
944+
[](TritonOpBuilder &self, const std::string &fileName, int line,
945+
int column) -> Location {
946+
return mlir::FileLineColLoc::get(self.getContext(), fileName, line,
947+
column);
948+
})
949+
.def(
950+
"create_name_loc",
951+
[](TritonOpBuilder &self, std::string name,
952+
std::optional<Location> childLoc) -> Location {
953+
auto nameAttr = StringAttr::get(self.getContext(), name);
954+
if (childLoc)
955+
return NameLoc::get(nameAttr, *childLoc);
956+
return NameLoc::get(nameAttr);
957+
},
958+
py::arg("name"), py::arg("child_loc") = py::none())
932959
.def("set_loc",
933960
[](TritonOpBuilder &self, const std::string &fileName, int line,
934961
int column) { self.setLastLoc(fileName, line, column); })

0 commit comments

Comments
 (0)