@@ -24,6 +24,7 @@ limitations under the License.
2424#include < variant>
2525#include < vector>
2626
27+ #include " absl/algorithm/container.h"
2728#include " absl/log/check.h"
2829#include " absl/log/log.h"
2930#include " absl/status/status.h"
@@ -329,13 +330,21 @@ bool DotCanSupportShapeWithLayout(const HloInstruction* dot,
329330 .ok ();
330331}
331332
333+ bool IsPackedInstruction (const HloInstruction* instruction) {
334+ return primitive_util::IsSubByteNonPredType (
335+ instruction->shape ().element_type ()) ||
336+ (instruction->opcode () == HloOpcode::kConvert &&
337+ primitive_util::IsSubByteNonPredType (
338+ instruction->operand (0 )->shape ().element_type ()));
339+ }
340+
332341} // namespace
333342
334343absl::Status GpuLayoutAssignment::AddDotBackendConstraints (
335344 LayoutConstraints* constraints, HloDotInstruction* instruction) {
336345 struct Side {
337346 size_t operand_no;
338- const Shape* shape ;
347+ const HloInstruction* operand ;
339348 absl::Span<const int64_t > batch_dims;
340349 absl::Span<const int64_t > contracting_dims;
341350 PrimitiveType type;
@@ -344,12 +353,13 @@ absl::Status GpuLayoutAssignment::AddDotBackendConstraints(
344353 auto make_side =
345354 [&](size_t operand_no, absl::Span<const int64_t > batch_dims,
346355 absl::Span<const int64_t > contracting_dims) -> absl::StatusOr<Side> {
347- Side side = {operand_no, &instruction->operand (operand_no)->shape (),
348- batch_dims, contracting_dims};
349- side.type = side.shape ->element_type ();
350- TF_ASSIGN_OR_RETURN (side.non_contracting_dims ,
351- GetNonContractingDims (*side.shape , side.batch_dims ,
352- side.contracting_dims ));
356+ Side side = {operand_no, instruction->operand (operand_no), batch_dims,
357+ contracting_dims};
358+ side.type = side.operand ->shape ().element_type ();
359+ TF_ASSIGN_OR_RETURN (
360+ side.non_contracting_dims ,
361+ GetNonContractingDims (side.operand ->shape (), side.batch_dims ,
362+ side.contracting_dims ));
353363 return side;
354364 };
355365 const DotDimensionNumbers& dot_dims = instruction->dot_dimension_numbers ();
@@ -372,6 +382,11 @@ absl::Status GpuLayoutAssignment::AddDotBackendConstraints(
372382 ->config ()
373383 .debug_options ()
374384 .xla_gpu_ensure_minor_dot_contraction_dims ();
385+ const bool pack_along_contracting_dims =
386+ instruction->GetModule ()
387+ ->config ()
388+ .debug_options ()
389+ .xla_gpu_experimental_pack_dot_operands_along_k_dimension ();
375390
376391 const bool is_bf16_to_bf16 =
377392 (output_type == PrimitiveType::BF16 && lhs.type == PrimitiveType::BF16 &&
@@ -388,11 +403,11 @@ absl::Status GpuLayoutAssignment::AddDotBackendConstraints(
388403 is_s8_to_s32 || is_fp8_to_fp8;
389404
390405 for (const Side& side : {lhs, rhs}) {
391- if (both_operands_require_minor_contraction_dims) {
392- TF_RETURN_IF_ERROR ( SetOperandMajorToMinorLayout (
393- instruction, side. operand_no ,
394- /* dim_groups= */
395- { side.batch_dims , side. non_contracting_dims , side. contracting_dims } ));
406+ if (( IsPackedInstruction (side. operand ) && pack_along_contracting_dims) ||
407+ both_operands_require_minor_contraction_dims) {
408+ TF_RETURN_IF_ERROR ( SetDotOperandLayoutToMinorContracting (
409+ instruction, side. operand_no , side. batch_dims , side. contracting_dims ,
410+ side.non_contracting_dims ));
396411 } else if (!side.batch_dims .empty () || side.contracting_dims .size () > 1 ||
397412 side.non_contracting_dims .size () > 1 ) {
398413 TF_RETURN_IF_ERROR (SetDotOperandLayout (
@@ -571,6 +586,42 @@ absl::Status GpuLayoutAssignment::SetDotOperandLayout(
571586 /* dim_groups=*/ {batch_dims, row_dims, col_dims});
572587}
573588
// Constrains the layout of dot operand `operand` so that its contracting
// dimensions are physically minor. If the operand already has a layout whose
// minormost physical dimensions are all contracting dimensions — and that
// layout is expressible as a MatrixLayout — the existing layout is kept (but
// re-set so it becomes a mandatory constraint). Otherwise a default
// major-to-minor layout is imposed with dimension groups ordered
// {batch, non-contracting, contracting}, which places the contracting
// dimensions minormost.
absl::Status GpuLayoutAssignment::SetDotOperandLayoutToMinorContracting(
    const HloInstruction* instruction, int64_t operand,
    absl::Span<const int64_t> batch_dims,
    absl::Span<const int64_t> contracting_dims,
    absl::Span<const int64_t> noncontracting_dims) {
  Shape shape = instruction->operand(operand)->shape();

  if (shape.has_layout() &&
      shape.layout().minor_to_major_size() >= contracting_dims.size()) {
    // Check that the contracting dimensions are physically minor, i.e. check
    // that minor physical dimensions all point to contracting logical
    // dimensions.
    bool contracting_dims_are_minor = true;
    const auto& minor_to_major = shape.layout().minor_to_major();
    // Only the first contracting_dims.size() entries of minor_to_major (the
    // minormost physical dims) need to be inspected.
    for (int64_t i = 0; i < contracting_dims.size(); ++i) {
      if (!absl::c_linear_search(contracting_dims, minor_to_major[i])) {
        contracting_dims_are_minor = false;
        break;
      }
    }

    // If contracting dims are already minor, and the layout is valid, keep it.
    // MatrixLayout::For() returning ok() verifies the layout is one a GEMM can
    // actually consume for this (batch, non-contracting, contracting) split.
    if (contracting_dims_are_minor &&
        MatrixLayout::For(shape, batch_dims, noncontracting_dims,
                          contracting_dims)
            .ok()) {
      // Re-set the operand layout, so it becomes mandatory.
      return SetOperandLayout(shape, instruction, operand);
    }
  }
  // Fall back: force a major-to-minor layout with contracting dims minormost.
  return SetOperandMajorToMinorLayout(
      instruction, operand,
      /*dim_groups=*/
      {batch_dims, noncontracting_dims, contracting_dims});
}
624+
574625absl::Status GpuLayoutAssignment::SetOperandMajorToMinorLayout (
575626 const HloInstruction* instruction, int64_t operand,
576627 std::initializer_list<absl::Span<const int64_t >> dim_groups) {
0 commit comments