Skip to content

Commit 521d1a8

Browse files
loislo and Google-ML-Automation
authored and committed
[XLA:GPU] Fix the unpack dim calculation for I4 rewrite with non major_2_minor layouts
PiperOrigin-RevId: 715289044
1 parent 961e5c2 commit 521d1a8

File tree

5 files changed

+91
-7
lines changed

5 files changed

+91
-7
lines changed

xla/service/gpu/BUILD

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1634,7 +1634,11 @@ cc_library(
16341634
]) + xla_internal(["service:export_hlo"]) + if_google([
16351635
"//xla/hlo/experimental/auto_sharding",
16361636
"//xla/hlo/experimental/auto_sharding:auto_sharding_option",
1637-
]),
1637+
]) + [
1638+
"//xla/tsl/platform:env",
1639+
"//xla/tsl/platform:errors",
1640+
"//xla/tsl/platform:statusor",
1641+
],
16381642
)
16391643

16401644
xla_test(

xla/service/gpu/fusions/triton/BUILD

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,9 @@ cc_library(
242242
"//xla/service/llvm_ir:llvm_util",
243243
"//xla/stream_executor:device_description",
244244
"//xla/stream_executor:launch_dim",
245+
"//xla/tsl/platform:errors",
246+
"//xla/tsl/platform:status",
247+
"//xla/tsl/platform:statusor",
245248
"@com_google_absl//absl/algorithm:container",
246249
"@com_google_absl//absl/container:flat_hash_map",
247250
"@com_google_absl//absl/container:flat_hash_set",

xla/service/gpu/fusions/triton/triton_fusion_emitter_int4_device_test.cc

Lines changed: 59 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,48 @@ class PlainInt4ToPackedInt4RewritePassTest : public TritonTest {
9696
};
9797

9898
TEST_F(PlainInt4ToPackedInt4RewritePassTest,
99-
DotWithI4WeightsOnLhsFusedWithMultiplyByChannelScales) {
99+
DotWithI4WeightsOnLhsWithNonStandardLayoutAndMultplyInEpilogue) {
100+
constexpr absl::string_view kHloText = R"(
101+
HloModule hlo
102+
103+
fusion {
104+
p_0 = s4[1,128,32]{1,2,0:E(4)} parameter(0)
105+
p_0.1 = s4[1,32,128]{2,1,0:E(4)} bitcast(p_0)
106+
p_0.2 = bf16[1,32,128]{2,1,0} convert(p_0.1)
107+
p_0.3 = bf16[1,128,32]{1,2,0} bitcast(p_0.2)
108+
p_1 = bf16[128,1,64]{2,1,0} parameter(1)
109+
dot = bf16[1,32,64]{2,1,0} dot(p_0.3, p_1),
110+
lhs_batch_dims={0},
111+
lhs_contracting_dims={1},
112+
rhs_batch_dims={1},
113+
rhs_contracting_dims={0}
114+
p_2 = bf16[1,1,32]{2,0,1} parameter(2)
115+
p_2.1 = bf16[1,32]{1,0} bitcast(p_2)
116+
p_2.2 = bf16[1,32,64]{2,1,0} broadcast(p_2.1), dimensions={0,1}
117+
m = bf16[1,32,64]{2,1,0} multiply(dot, p_2.2)
118+
ROOT m.1 = bf16[1,1,32,64]{3,2,1,0} bitcast(m)
119+
}
120+
121+
ENTRY %entry_computation {
122+
p_0 = s4[1,128,32]{1,2,0:E(4)} parameter(0)
123+
p_1 = bf16[128,1,64]{2,1,0} parameter(1)
124+
p_2 = bf16[1,1,32]{2,0,1} parameter(2)
125+
ROOT gemm_fusion_dot.2 = bf16[1,1,32,64]{3,2,1,0} fusion(p_0, p_1, p_2),
126+
kind=kCustom,
127+
calls=fusion,
128+
backend_config={
129+
"fusion_backend_config":{
130+
"kind":"__triton_gemm"
131+
}
132+
}
133+
}
134+
)";
135+
EXPECT_TRUE(RunAndCompareNoHloPasses(
136+
kHloText, ErrorSpec{/*aabs=*/1e-5, /*arel=*/1e-5}));
137+
}
138+
139+
TEST_F(PlainInt4ToPackedInt4RewritePassTest,
140+
DotWithInt4WeightsOnLhsFusedWithMultiplyByChannelScales) {
100141
constexpr absl::string_view kHloText = R"(
101142
HloModule DotWithI4WeightsOnLhsFusedWithMultiplyByChannelScales
102143
@@ -133,6 +174,23 @@ TEST_F(PlainInt4ToPackedInt4RewritePassTest,
133174
kHloText, ErrorSpec{/*aabs=*/1e-5, /*arel=*/1e-5}));
134175
}
135176

177+
TEST_F(PlainInt4ToPackedInt4RewritePassTest, NonstandardLayoutInt4) {
178+
constexpr absl::string_view kHloText = R"(
179+
HloModule NonstandardLayoutInt4
180+
181+
ENTRY main {
182+
p0 = s4[64,128]{0,1} parameter(0)
183+
p1 = bf16[256,64]{1,0} parameter(1)
184+
ROOT %dot = bf16[128,256]{1,0} dot(s4[64,128]{0,1} p0, bf16[256,64]{1,0} p1),
185+
lhs_contracting_dims={0},
186+
rhs_contracting_dims={1}
187+
}
188+
)";
189+
190+
TF_ASSERT_OK_AND_ASSIGN(auto module, GetOptimizedModule(kHloText));
191+
EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3}));
192+
}
193+
136194
using ::testing::TestParamInfo;
137195
using ::testing::WithParamInterface;
138196

xla/service/gpu/fusions/triton/triton_fusion_emitter_legacy_matmul.cc

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,8 @@ limitations under the License.
6161
#include "xla/hlo/ir/hlo_opcode.h"
6262
#include "xla/hlo/utils/hlo_query.h"
6363
#include "xla/hlo/utils/hlo_traversal.h"
64+
#include "xla/layout.h"
65+
#include "xla/layout_util.h"
6466
#include "xla/literal.h"
6567
#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h"
6668
#include "xla/mlir_hlo/mhlo/transforms/map_mhlo_to_scalar_op.h"
@@ -73,6 +75,7 @@ limitations under the License.
7375
#include "xla/service/gpu/ir_emission_utils.h"
7476
#include "xla/service/gpu/launch_dimensions.h"
7577
#include "xla/service/gpu/matmul_indexing_utils.h"
78+
#include "xla/service/gpu/matmul_utils.h"
7679
#include "xla/service/gpu/model/tiled_hlo_computation.h"
7780
#include "xla/service/gpu/triton_fusion_analysis.h"
7881
#include "xla/service/gpu/triton_tiling_propagation.h"
@@ -82,6 +85,9 @@ limitations under the License.
8285
#include "xla/status_macros.h"
8386
#include "xla/stream_executor/device_description.h"
8487
#include "xla/stream_executor/launch_dim.h"
88+
#include "xla/tsl/platform/errors.h"
89+
#include "xla/tsl/platform/status.h"
90+
#include "xla/tsl/platform/statusor.h"
8591
#include "xla/util.h"
8692
#include "xla/xla_data.pb.h"
8793
#include "tsl/platform/errors.h"
@@ -1477,7 +1483,8 @@ class MatMulEmitterHelper {
14771483
.getResult());
14781484
if (hlo->shape().element_type() == PrimitiveType::S4 &&
14791485
IsTritonInt4RewritesEnabled(*hlo)) {
1480-
tensor_ptr.getDefiningOp()->setAttr("packed_dim", GetPackedDimAttr(side));
1486+
tensor_ptr.getDefiningOp()->setAttr(
1487+
"packed_dim", GetPackedDimAttr(side, hlo->shape().layout()));
14811488
}
14821489
tensor_ptr = b_.create<mt::AdvanceOp>(tensor_ptr.getType(), tensor_ptr,
14831490
block_offsets);
@@ -1486,16 +1493,22 @@ class MatMulEmitterHelper {
14861493

14871494
// Naive implementation of the packed_dim attribute for the int4 tensors.
14881495
// It doesn't take into account different layout schemes.
1489-
mlir::IntegerAttr GetPackedDimAttr(const Side& side) const {
1496+
mlir::IntegerAttr GetPackedDimAttr(const Side& side,
1497+
const Layout& layout) const {
14901498
int packed_dim = 0;
1499+
const std::vector<int64_t> logical_to_physical =
1500+
LayoutUtil::MakeLogicalToPhysical(layout);
1501+
14911502
if (side.scope == TritonFusionAnalysis::Scope::LHS) {
1492-
if (dims_.lhs_contracting_dim_idx > dims_.lhs_noncontracting_dim_idx) {
1503+
if (logical_to_physical[dims_.lhs_contracting_dim_idx] >
1504+
logical_to_physical[dims_.lhs_noncontracting_dim_idx]) {
14931505
packed_dim = 0;
14941506
} else {
14951507
packed_dim = 1;
14961508
}
14971509
} else if (side.scope == TritonFusionAnalysis::Scope::RHS) {
1498-
if (dims_.rhs_contracting_dim_idx > dims_.rhs_noncontracting_dim_idx) {
1510+
if (logical_to_physical[dims_.rhs_contracting_dim_idx] >
1511+
logical_to_physical[dims_.rhs_noncontracting_dim_idx]) {
14991512
packed_dim = 1;
15001513
} else {
15011514
packed_dim = 0;

xla/service/gpu/gpu_compiler.cc

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,10 @@ limitations under the License.
263263
#include "xla/stream_executor/platform_manager.h"
264264
#include "xla/stream_executor/semantic_version.h"
265265
#include "xla/stream_executor/stream_executor.h"
266+
#include "xla/tsl/platform/env.h"
267+
#include "xla/tsl/platform/errors.h"
268+
#include "xla/tsl/platform/statusor.h"
269+
#include "xla/tsl/platform/threadpool.h"
266270
#include "xla/util.h"
267271
#include "xla/xla.pb.h"
268272
#include "xla/xla_data.pb.h"
@@ -1552,7 +1556,9 @@ absl::Status GpuCompiler::OptimizeHloPostLayoutAssignment(
15521556
pipeline.AddPass<GemvRewriter>();
15531557
pipeline.AddPass<GemmFusion>(gpu_version);
15541558
pipeline.AddPass<GemmFusionSwapOperands>();
1555-
pipeline.AddPass<SimplifyInt4Dots>();
1559+
if (!debug_options.xla_gpu_experimental_enable_triton_i4_rewrites()) {
1560+
pipeline.AddPass<SimplifyInt4Dots>();
1561+
}
15561562
} else if (cuda_cc != nullptr &&
15571563
cuda_cc->major == se::CudaComputeCapability::VOLTA) {
15581564
// Greedy pattern matching for custom kernel fusions.

0 commit comments

Comments (0)