
Commit be9504c

mooskagh authored and Google-ML-Automation committed
[XLA:GPU] Require packed dot operands to be packed along contracting dimension.
For now, only do that if `--xla_gpu_experimental_pack_dot_operands_along_k_dimension` is set.

PiperOrigin-RevId: 715355925
1 parent 34acbf0 commit be9504c
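
The flag defaults to off, so the new layout constraint only takes effect when the debug option is set. A minimal sketch of enabling it programmatically on an HloModule (mirroring what the new tests below do; `module` stands in for an existing module):

    // Sketch only: flip the experimental option on a module's debug options.
    // The setter is the one registered by this commit.
    DebugOptions debug_options = module->config().debug_options();
    debug_options.set_xla_gpu_experimental_pack_dot_operands_along_k_dimension(
        true);
    module->mutable_config().set_debug_options(debug_options);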

File tree

7 files changed: +164 -19 lines changed


xla/debug_options_flags.cc

Lines changed: 8 additions & 0 deletions
@@ -323,6 +323,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() {
   opts.set_xla_pjrt_allow_auto_layout_in_hlo(false);
   opts.set_xla_gpu_enable_scatter_determinism_expander(true);
   opts.set_xla_gpu_unsupported_enable_ragged_all_to_all_decomposer(false);
+  opts.set_xla_gpu_experimental_pack_dot_operands_along_k_dimension(false);
   return opts;
 }
 
@@ -2230,6 +2231,13 @@ void MakeDebugOptionsFlags(std::vector<tsl::Flag>* flag_list,
       "Enable windowed einsum rewrite for all-to-all+gemm pattern, "
       "This optimization slices the all-to-all into smaller all-to-alls."
       "It is an experimental feature."));
+  flag_list->push_back(tsl::Flag(
+      "xla_gpu_experimental_pack_dot_operands_along_k_dimension",
+      bool_setter_for(
+          &DebugOptions::
+              set_xla_gpu_experimental_pack_dot_operands_along_k_dimension),
+      debug_options->xla_gpu_experimental_pack_dot_operands_along_k_dimension(),
+      "For sub-byte dot operands, layout them along contracting dimensions."));
 }  // NOLINT(readability/fn_size)
 
 // Allocates flag_values and flag_objects; this function must not be called more

xla/service/gpu/fusions/triton/triton_fusion_emitter_device_legacy_test.cc

Lines changed: 13 additions & 5 deletions
@@ -4136,18 +4136,22 @@ HloModule m
 ENTRY e {
   parameter_0 = bf16[32,4,36]{2,1,0} parameter(0)
   parameter_1 = bf16[40,4,36]{2,1,0} parameter(1)
-  ROOT dot.16450 = bf16[4,32,40]{2,1,0} dot(parameter_0, parameter_1), lhs_batch_dims={1}, lhs_contracting_dims={2}, rhs_batch_dims={1}, rhs_contracting_dims={2}
+  ROOT dot.16450 = bf16[4,32,40]{2,1,0} dot(parameter_0, parameter_1),
+    lhs_batch_dims={1}, lhs_contracting_dims={2},
+    rhs_batch_dims={1}, rhs_contracting_dims={2}
 })";
 
   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                           GetOptimizedModule(kHloText));
 
+  // The contracting dims were already minor, so the layout is unchanged
+  // (non-major batch dims are fine).
   EXPECT_THAT(module->entry_computation()
                   ->root_instruction()
                   ->fused_instructions_computation()
                   ->root_instruction(),
-              GmockMatch(m::Dot(m::Op().WithShape(BF16, {32, 4, 36}, {2, 0, 1}),
-                                m::Op().WithShape(BF16, {40, 4, 36}, {2, 0, 1}))
+              GmockMatch(m::Dot(m::Op().WithShape(BF16, {32, 4, 36}, {2, 1, 0}),
+                                m::Op().WithShape(BF16, {40, 4, 36}, {2, 1, 0}))
                              .WithShape(BF16, {4, 32, 40}, {2, 1, 0})));
 }
 
@@ -4161,18 +4165,22 @@ HloModule m
 ENTRY e {
   parameter_1 = bf16[16,16,48]{2,1,0} parameter(1)
   parameter_2 = bf16[16,48,32]{2,1,0} parameter(0)
-  ROOT dot.16125 = bf16[16,16,32]{2,1,0} dot(parameter_1, parameter_2), lhs_batch_dims={1}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={1}
+  ROOT dot.16125 = bf16[16,16,32]{2,1,0} dot(parameter_1, parameter_2),
+    lhs_batch_dims={1}, lhs_contracting_dims={2},
+    rhs_batch_dims={0}, rhs_contracting_dims={1}
 })";
 
   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                           GetOptimizedModule(kHloText));
 
+  // lhs already has minor contracting dims, so its layout is unchanged.
+  // rhs changes layout to have minor contracting dims.
   EXPECT_THAT(
       module->entry_computation()
           ->root_instruction()
           ->fused_instructions_computation()
           ->root_instruction(),
-      GmockMatch(m::Dot(m::Op().WithShape(BF16, {16, 16, 48}, {2, 0, 1}),
+      GmockMatch(m::Dot(m::Op().WithShape(BF16, {16, 16, 48}, {2, 1, 0}),
                         m::Op().WithShape(BF16, {16, 48, 32}, {1, 2, 0}))
                      .WithShape(BF16, {16, 16, 32}, {2, 1, 0})));
 }
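
As a reminder of the layout notation used in the matchers above, minor_to_major lists logical dimensions from fastest-varying to slowest, so {2, 1, 0} keeps dim 2 innermost. A small illustrative snippet (not part of the change; standard ShapeUtil API assumed):

    // Illustrative only: the lhs shape from the first test with its expected
    // layout. Dim 2 is the contracting dim and is the minor-most physical
    // dim, so no relayout of the operand is needed.
    Shape lhs = ShapeUtil::MakeShapeWithDenseLayout(
        BF16, {32, 4, 36}, /*minor_to_major=*/{2, 1, 0});
    CHECK_EQ(lhs.layout().minor_to_major(0), 2);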

xla/service/gpu/transforms/BUILD

Lines changed: 2 additions & 0 deletions
@@ -2167,6 +2167,7 @@ cc_library(
         "//xla/stream_executor:device_description",
         "//xla/stream_executor:dnn",
         "//xla/tsl/util:env_var",
+        "@com_google_absl//absl/algorithm:container",
         "@com_google_absl//absl/log",
         "@com_google_absl//absl/log:check",
         "@com_google_absl//absl/status",
 
@@ -2188,6 +2189,7 @@ xla_cc_test(
         "//xla/hlo/ir:hlo",
         "//xla/hlo/parser:hlo_parser",
         "//xla/hlo/testlib:filecheck",
+        "//xla/hlo/testlib:pattern_matcher_gmock",
         "//xla/service:computation_layout",
         "//xla/service:pattern_matcher",
         "//xla/service:pattern_matcher_gmock",

xla/service/gpu/transforms/layout_assignment.cc

Lines changed: 63 additions & 12 deletions
@@ -24,6 +24,7 @@ limitations under the License.
 #include <variant>
 #include <vector>
 
+#include "absl/algorithm/container.h"
 #include "absl/log/check.h"
 #include "absl/log/log.h"
 #include "absl/status/status.h"
 
@@ -329,13 +330,21 @@ bool DotCanSupportShapeWithLayout(const HloInstruction* dot,
       .ok();
 }
 
+bool IsPackedInstruction(const HloInstruction* instruction) {
+  return primitive_util::IsSubByteNonPredType(
+             instruction->shape().element_type()) ||
+         (instruction->opcode() == HloOpcode::kConvert &&
+          primitive_util::IsSubByteNonPredType(
+              instruction->operand(0)->shape().element_type()));
+}
+
 }  // namespace
 
 absl::Status GpuLayoutAssignment::AddDotBackendConstraints(
     LayoutConstraints* constraints, HloDotInstruction* instruction) {
   struct Side {
     size_t operand_no;
-    const Shape* shape;
+    const HloInstruction* operand;
     absl::Span<const int64_t> batch_dims;
     absl::Span<const int64_t> contracting_dims;
     PrimitiveType type;
 
@@ -344,12 +353,13 @@ absl::Status GpuLayoutAssignment::AddDotBackendConstraints(
   auto make_side =
       [&](size_t operand_no, absl::Span<const int64_t> batch_dims,
           absl::Span<const int64_t> contracting_dims) -> absl::StatusOr<Side> {
-    Side side = {operand_no, &instruction->operand(operand_no)->shape(),
-                 batch_dims, contracting_dims};
-    side.type = side.shape->element_type();
-    TF_ASSIGN_OR_RETURN(side.non_contracting_dims,
-                        GetNonContractingDims(*side.shape, side.batch_dims,
-                                              side.contracting_dims));
+    Side side = {operand_no, instruction->operand(operand_no), batch_dims,
+                 contracting_dims};
+    side.type = side.operand->shape().element_type();
+    TF_ASSIGN_OR_RETURN(
+        side.non_contracting_dims,
+        GetNonContractingDims(side.operand->shape(), side.batch_dims,
+                              side.contracting_dims));
     return side;
   };
   const DotDimensionNumbers& dot_dims = instruction->dot_dimension_numbers();
 
@@ -372,6 +382,11 @@ absl::Status GpuLayoutAssignment::AddDotBackendConstraints(
           ->config()
           .debug_options()
           .xla_gpu_ensure_minor_dot_contraction_dims();
+  const bool pack_along_contracting_dims =
+      instruction->GetModule()
+          ->config()
+          .debug_options()
+          .xla_gpu_experimental_pack_dot_operands_along_k_dimension();
 
   const bool is_bf16_to_bf16 =
       (output_type == PrimitiveType::BF16 && lhs.type == PrimitiveType::BF16 &&
 
@@ -388,11 +403,11 @@ absl::Status GpuLayoutAssignment::AddDotBackendConstraints(
       is_s8_to_s32 || is_fp8_to_fp8;
 
   for (const Side& side : {lhs, rhs}) {
-    if (both_operands_require_minor_contraction_dims) {
-      TF_RETURN_IF_ERROR(SetOperandMajorToMinorLayout(
-          instruction, side.operand_no,
-          /*dim_groups=*/
-          {side.batch_dims, side.non_contracting_dims, side.contracting_dims}));
+    if ((IsPackedInstruction(side.operand) && pack_along_contracting_dims) ||
+        both_operands_require_minor_contraction_dims) {
+      TF_RETURN_IF_ERROR(SetDotOperandLayoutToMinorContracting(
+          instruction, side.operand_no, side.batch_dims, side.contracting_dims,
+          side.non_contracting_dims));
     } else if (!side.batch_dims.empty() || side.contracting_dims.size() > 1 ||
                side.non_contracting_dims.size() > 1) {
       TF_RETURN_IF_ERROR(SetDotOperandLayout(
 
@@ -571,6 +586,42 @@ absl::Status GpuLayoutAssignment::SetDotOperandLayout(
       /*dim_groups=*/{batch_dims, row_dims, col_dims});
 }
 
+absl::Status GpuLayoutAssignment::SetDotOperandLayoutToMinorContracting(
+    const HloInstruction* instruction, int64_t operand,
+    absl::Span<const int64_t> batch_dims,
+    absl::Span<const int64_t> contracting_dims,
+    absl::Span<const int64_t> noncontracting_dims) {
+  Shape shape = instruction->operand(operand)->shape();
+
+  if (shape.has_layout() &&
+      shape.layout().minor_to_major_size() >= contracting_dims.size()) {
+    // Check that the contracting dimensions are physically minor, i.e. check
+    // that minor physical dimensions all point to contracting logical
+    // dimensions.
+    bool contracting_dims_are_minor = true;
+    const auto& minor_to_major = shape.layout().minor_to_major();
+    for (int64_t i = 0; i < contracting_dims.size(); ++i) {
+      if (!absl::c_linear_search(contracting_dims, minor_to_major[i])) {
+        contracting_dims_are_minor = false;
+        break;
+      }
+    }
+
+    // If contracting dims are already minor, and the layout is valid, keep it.
+    if (contracting_dims_are_minor &&
+        MatrixLayout::For(shape, batch_dims, noncontracting_dims,
+                          contracting_dims)
+            .ok()) {
+      // Re-set the operand layout, so it becomes mandatory.
+      return SetOperandLayout(shape, instruction, operand);
+    }
+  }
+  return SetOperandMajorToMinorLayout(
+      instruction, operand,
+      /*dim_groups=*/
+      {batch_dims, noncontracting_dims, contracting_dims});
+}
+
 absl::Status GpuLayoutAssignment::SetOperandMajorToMinorLayout(
     const HloInstruction* instruction, int64_t operand,
     std::initializer_list<absl::Span<const int64_t>> dim_groups) {
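
In effect, the new helper does one of two things for a packed operand: if the operand already carries a layout whose minor physical dims are the contracting dims and MatrixLayout::For accepts it, that layout is re-set as mandatory; otherwise the operand is forced to major-to-minor order over {batch, noncontracting, contracting}. A worked illustration for the s4 case exercised by the tests below (sketch only, standard ShapeUtil API assumed):

    // For an s4[5120,128] operand with contracting_dims={0}, the forced
    // grouping {batch, noncontracting, contracting} = {{}, {1}, {0}} gives
    // minor_to_major = {0, 1}: the contracting (K) dim becomes minor-most,
    // so the packed sub-byte values are contiguous along K.
    Shape forced = ShapeUtil::MakeShapeWithDenseLayout(
        S4, {5120, 128}, /*minor_to_major=*/{0, 1});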

xla/service/gpu/transforms/layout_assignment.h

Lines changed: 6 additions & 0 deletions
@@ -65,6 +65,12 @@ class GpuLayoutAssignment : public LayoutAssignment {
                                    absl::Span<const int64_t> row_dims,
                                    absl::Span<const int64_t> col_dims);
 
+  absl::Status SetDotOperandLayoutToMinorContracting(
+      const HloInstruction* instruction, int64_t operand,
+      absl::Span<const int64_t> batch_dims,
+      absl::Span<const int64_t> contracting_dims,
+      absl::Span<const int64_t> noncontracting_dims);
+
   absl::Status SetDotLayout(const HloInstruction* instruction,
                             LayoutConstraints* constraints);
 
xla/service/gpu/transforms/layout_assignment_test.cc

Lines changed: 68 additions & 1 deletion
@@ -27,12 +27,12 @@ limitations under the License.
 #include "xla/hlo/ir/hlo_opcode.h"
 #include "xla/hlo/parser/hlo_parser.h"
 #include "xla/hlo/testlib/filecheck.h"
+#include "xla/hlo/testlib/pattern_matcher_gmock.h"
 #include "xla/layout.h"
 #include "xla/layout_util.h"
 #include "xla/service/computation_layout.h"
 #include "xla/service/gpu/stream_executor_util.h"
 #include "xla/service/pattern_matcher.h"
-#include "xla/service/pattern_matcher_gmock.h"
 #include "xla/shape.h"
 #include "xla/shape_layout.h"
 #include "xla/shape_util.h"
 
@@ -770,6 +770,73 @@ TEST_F(LayoutAssignmentTest, AutoLayoutE4M3ContractingMinorFirst) {
                   .WithShape(F32, {128, 10240}, {1, 0})));
 }
 
+TEST_F(LayoutAssignmentTest, AutoLayoutS4DotContractingMinorLhs) {
+  const char* hlo = R"(
+HloModule AutoLayoutS4DotContractingMinorLhs
+
+ENTRY main {
+  p0 = s4[5120,128] parameter(0)
+  p0.c = bf16[5120,128] convert(p0)
+  p1 = bf16[5120,10240] parameter(1)
+  ROOT dot = bf16[128,10240] dot(p0.c, p1), lhs_contracting_dims={0}, rhs_contracting_dims={0}
+})";
+  TF_ASSERT_OK_AND_ASSIGN(
+      std::unique_ptr<HloModule> m,
+      ParseAndReturnUnverifiedModule(
+          hlo, {}, HloParserOptions().set_fill_missing_layouts(false)));
+  DebugOptions debug_options = m->config().debug_options();
+  debug_options.set_xla_gpu_experimental_pack_dot_operands_along_k_dimension(
+      true);
+  m->mutable_config().set_debug_options(debug_options);
+  ComputationLayout computation_layout(
+      m->entry_computation()->ComputeProgramShape(),
+      /*ignore_layouts=*/false);
+  GpuLayoutAssignment layout_assignment(
+      &computation_layout, GetGpuComputeCapability(), GetDnnVersion(),
+      GetDeviceDescription());
+  EXPECT_THAT(layout_assignment.Run(m.get()), IsOkAndHolds(true));
+  EXPECT_THAT(m->entry_computation()->parameter_instruction(0),
+              GmockMatch(m::Parameter(0).WithShape(S4, {5120, 128}, {0, 1})));
+  EXPECT_THAT(
+      m->entry_computation()->parameter_instruction(1),
+      GmockMatch(m::Parameter(1).WithShape(BF16, {5120, 10240}, {1, 0})));
+  EXPECT_THAT(m->entry_computation()->root_instruction(),
+              GmockMatch(m::Dot().WithShape(BF16, {128, 10240}, {1, 0})));
+}
+
+TEST_F(LayoutAssignmentTest, AutoLayoutS4DotContractingMinorRhs) {
+  const char* hlo = R"(
+HloModule AutoLayoutS4DotContractingMinorRhs
+
+ENTRY main {
+  p0 = bf16[5120,128] parameter(0)
+  p1 = s4[5120,10240] parameter(1)
+  p1.c = bf16[5120,10240] convert(p1)
+  ROOT dot = bf16[128,10240] dot(p0, p1.c), lhs_contracting_dims={0}, rhs_contracting_dims={0}
+})";
+  TF_ASSERT_OK_AND_ASSIGN(
+      std::unique_ptr<HloModule> m,
+      ParseAndReturnUnverifiedModule(
+          hlo, {}, HloParserOptions().set_fill_missing_layouts(false)));
+  DebugOptions debug_options = m->config().debug_options();
+  debug_options.set_xla_gpu_experimental_pack_dot_operands_along_k_dimension(
+      true);
+  m->mutable_config().set_debug_options(debug_options);
+  ComputationLayout computation_layout(
+      m->entry_computation()->ComputeProgramShape(),
+      /*ignore_layouts=*/false);
+  GpuLayoutAssignment layout_assignment(
+      &computation_layout, GetGpuComputeCapability(), GetDnnVersion(),
+      GetDeviceDescription());
+  EXPECT_THAT(layout_assignment.Run(m.get()), IsOkAndHolds(true));
+  EXPECT_THAT(m->entry_computation()->parameter_instruction(0),
+              GmockMatch(m::Parameter(0).WithShape(BF16, {5120, 128}, {1, 0})));
+  EXPECT_THAT(m->entry_computation()->parameter_instruction(1),
+              GmockMatch(m::Parameter(1).WithShape(S4, {5120, 10240}, {0, 1})));
+  EXPECT_THAT(m->entry_computation()->root_instruction(),
+              GmockMatch(m::Dot().WithShape(BF16, {128, 10240}, {1, 0})));
+}
+
 TEST_F(LayoutAssignmentTest, VariadicReduceSameOperandLayout) {
   const char* module_str = R"(
 HloModule variadic_reduce
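
One detail worth noting in these tests: the dot consumes a bf16 convert of the s4 parameter, and IsPackedInstruction also treats a convert whose input is sub-byte as packed, which is why the expected {0, 1} layout shows up on the s4 parameter itself rather than stopping at the bf16 convert. A tiny illustrative check of the type predicate involved (sketch only):

    // S4 is a sub-byte, non-PRED type, so operands of (or converts from) S4
    // are eligible for the packing constraint; BF16 on its own is not.
    CHECK(primitive_util::IsSubByteNonPredType(S4));
    CHECK(!primitive_util::IsSubByteNonPredType(BF16));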

xla/xla.proto

Lines changed: 4 additions & 1 deletion
@@ -1113,7 +1113,10 @@ message DebugOptions {
   // xla_gpu_multi_streamed_windowed_einsum is set to true.
   bool xla_gpu_experimental_enable_alltoall_windowed_einsum = 360;
 
-  // Next id: 362
+  // For sub-byte dot operands, layout them along contracting dimensions.
+  bool xla_gpu_experimental_pack_dot_operands_along_k_dimension = 362;
+
+  // Next id: 363
 
   // Extra options to pass to the compilation backend (e.g. LLVM); specific
   // interpretation of these values is left to the backend.
