Skip to content

Commit 250566a

Browse files
authored
Merge OpenAI commit 8b792c8 (#5100)
This PR change the Triton base from 33b2823 to 8b792c8 (Sep 8). Pass rate: 98.11%
2 parents cf519ef + 043818e commit 250566a

File tree

23 files changed

+1020
-217
lines changed

23 files changed

+1020
-217
lines changed

CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -307,7 +307,8 @@ if(TRITON_BUILD_PYTHON_MODULE)
307307
${PYTHON_SRC_PATH}/gluon_ir.cc
308308
${PYTHON_SRC_PATH}/passes.cc
309309
${PYTHON_SRC_PATH}/interpreter.cc
310-
${PYTHON_SRC_PATH}/llvm.cc)
310+
${PYTHON_SRC_PATH}/llvm.cc
311+
${PYTHON_SRC_PATH}/specialize.cc)
311312

312313
# Link triton with its dependencies
313314
target_link_libraries(triton PRIVATE ${TRITON_LIBRARIES})

python/src/gluon_ir.cc

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ struct GluonLayouts {
9797
py::handle NVMMASharedLayout;
9898
py::handle SwizzledSharedLayout;
9999
py::handle AMDMFMALayout;
100+
py::handle AMDWMMALayout;
100101
py::handle PaddedSharedLayout;
101102
py::handle GluonDType;
102103

@@ -117,6 +118,7 @@ struct GluonLayouts {
117118
SwizzledSharedLayout =
118119
py::object(layouts.attr("SwizzledSharedLayout")).release();
119120
AMDMFMALayout = py::object(amdLayouts.attr("AMDMFMALayout")).release();
121+
AMDWMMALayout = py::object(amdLayouts.attr("AMDWMMALayout")).release();
120122
PaddedSharedLayout =
121123
py::object(layouts.attr("PaddedSharedLayout")).release();
122124

@@ -226,6 +228,14 @@ py::object layoutToGluon(Attribute layout) {
226228
toStdVector(ctaLayout.getCTAsPerCGA()),
227229
toStdVector(ctaLayout.getCTASplitNum()),
228230
toStdVector(ctaLayout.getCTAOrder()));
231+
} else if (auto amdWmma = dyn_cast<ttg::AMDWmmaEncodingAttr>(layout)) {
232+
auto ctaLayout = amdWmma.getCTALayout();
233+
return layouts.AMDWMMALayout(amdWmma.getVersion(),
234+
amdWmma.getIsTransposed(),
235+
toStdVector(amdWmma.getWarpsPerCTA()),
236+
toStdVector(ctaLayout.getCTAsPerCGA()),
237+
toStdVector(ctaLayout.getCTASplitNum()),
238+
toStdVector(ctaLayout.getCTAOrder()));
229239
} else if (auto paddedShared =
230240
dyn_cast<ttg::PaddedSharedEncodingAttr>(layout)) {
231241
auto *ctx = paddedShared.getContext();
@@ -357,6 +367,18 @@ void init_gluon_ir(py::module &&m) {
357367
ctx, version, warpsPerCta, tilesPerWarp, instrShape[0],
358368
instrShape[1], transposed, ctaLayout, elemType);
359369
})
370+
.def("get_amd_wmma_layout",
371+
[](GluonOpBuilder &self, unsigned version, bool transposed,
372+
std::vector<unsigned> &warpsPerCta,
373+
std::vector<unsigned> &ctasPerCga,
374+
std::vector<unsigned> &ctaSplitNum,
375+
std::vector<unsigned> &ctaOrder) -> Attribute {
376+
auto ctx = self.getContext();
377+
auto ctaLayout = self.getChecked<ttg::CTALayoutAttr>(
378+
ctx, ctasPerCga, ctaSplitNum, ctaOrder);
379+
return ttg::AMDWmmaEncodingAttr::get(ctx, version, transposed,
380+
warpsPerCta, ctaLayout);
381+
})
360382
.def("get_padded_shared_layout",
361383
[](GluonOpBuilder &self, std::vector<unsigned> &intervals,
362384
std::vector<unsigned> &paddings,

python/src/main.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,12 +43,14 @@ void init_triton_interpreter(pybind11::module &&m);
4343
void init_triton_passes(pybind11::module &&m);
4444
void init_triton_stacktrace_hook(pybind11::module &m);
4545
void init_gluon_ir(pybind11::module &&m);
46+
void init_native_specialize(pybind11::module &m);
4647
FOR_EACH_P(DECLARE_BACKEND, TRITON_BACKENDS_TUPLE)
4748

4849
PYBIND11_MODULE(libtriton, m) {
4950
m.doc() = "Python bindings to the C++ Triton API";
5051
init_triton_stacktrace_hook(m);
5152
init_triton_env_vars(m);
53+
init_native_specialize(m);
5254
init_triton_ir(m.def_submodule("ir"));
5355
init_triton_passes(m.def_submodule("passes"));
5456
init_triton_interpreter(m.def_submodule("interpreter"));

0 commit comments

Comments
 (0)