Skip to content

Commit dc0b774

Browse files
tlongeriGoogle-ML-Automation
authored andcommitted
[Mosaic:TPU] Allow null parts for tpu.pack_subelements, meaning "don't care"
PiperOrigin-RevId: 707439259
1 parent 3262770 commit dc0b774

File tree

3 files changed

+137
-39
lines changed

3 files changed

+137
-39
lines changed

jaxlib/mosaic/dialect/tpu/tpu.td

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -371,13 +371,21 @@ def TPU_UnpackSubelementsOp : TPU_Op<"unpack_subelements", [Pure]> {
371371
}
372372

373373
// Integer packs are always signed at the moment.
374-
def TPU_PackSubelementsOp : TPU_Op<"pack_subelements", [Pure]> {
374+
def TPU_PackSubelementsOp : TPU_Op<"pack_subelements", [Pure, SameTypeOperands]> {
375375
let arguments = (ins
376-
Variadic<AnyVectorOfNonZeroRank>:$sources,
376+
Variadic<TPU_Vreg>:$sources,
377+
DenseI32ArrayAttr:$positions,
377378
TPU_PackFormatEnum:$pack_format
378379
);
379-
let results = (outs AnyVectorOfNonZeroRank:$output);
380+
let results = (outs TPU_Vreg:$output);
380381
let assemblyFormat = [{ $sources attr-dict `:` type($sources) `->` type($output) }];
382+
let builders = [
383+
OpBuilder<(ins "::mlir::VectorType":$output_type, "::mlir::ArrayRef<::mlir::Value>":$padded_sources, "::mlir::tpu::PackFormat":$pack_format)>,
384+
];
385+
let extraClassDeclaration = [{
386+
static ::mlir::SmallVector<::mlir::Value> getPaddedSources(::mlir::ValueRange sources, ::mlir::ArrayRef<int32_t> positions, int packing_factor);
387+
}];
388+
let hasVerifier = 1;
381389
}
382390

383391
def TPU_GatherOp : TPU_Op<"gather", [Pure]> {

jaxlib/mosaic/dialect/tpu/tpu_ops.cc

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
1313
limitations under the License.
1414
==============================================================================*/
1515

16+
#include <cstddef>
1617
#include <cstdint>
1718
#include <optional>
1819
#include <string_view>
@@ -1113,6 +1114,54 @@ LogicalResult WeirdOp::verify() {
11131114
return success();
11141115
}
11151116

1117+
void PackSubelementsOp::build(OpBuilder &builder, OperationState &state,
1118+
const VectorType output_type,
1119+
const ArrayRef<Value> padded_sources,
1120+
const PackFormat pack_format) {
1121+
SmallVector<Value> sources;
1122+
SmallVector<int32_t> positions;
1123+
for (size_t i = 0; i < padded_sources.size(); ++i) {
1124+
if (padded_sources[i] != nullptr) {
1125+
sources.push_back(padded_sources[i]);
1126+
positions.push_back(i);
1127+
}
1128+
}
1129+
build(builder, state, output_type, sources, positions, pack_format);
1130+
}
1131+
1132+
SmallVector<Value> PackSubelementsOp::getPaddedSources(
1133+
ValueRange sources, const ArrayRef<int32_t> positions,
1134+
const int packing_factor) {
1135+
SmallVector<Value> padded_sources(packing_factor);
1136+
for (const auto [source, position] : llvm::zip(sources, positions)) {
1137+
padded_sources[position] = source;
1138+
}
1139+
return padded_sources;
1140+
}
1141+
1142+
LogicalResult PackSubelementsOp::verify() {
1143+
if (getSources().empty()) {
1144+
return emitOpError("At least one source is required");
1145+
}
1146+
if (getPositions().size() != getSources().size()) {
1147+
return emitOpError("Size of sources and positions must match");
1148+
}
1149+
const int packing_factor = cast<VectorType>(getSources().front().getType())
1150+
.getElementTypeBitWidth() /
1151+
getType().getElementTypeBitWidth();
1152+
SmallVector<bool> seen_positions(packing_factor, false);
1153+
for (const int32_t position : getPositions()) {
1154+
if (position < 0 || packing_factor <= position) {
1155+
return emitOpError("Positions must be between 0 and the packing factor");
1156+
}
1157+
if (seen_positions[position]) {
1158+
return emitOpError("Positions must be unique");
1159+
}
1160+
seen_positions[position] = true;
1161+
}
1162+
return success();
1163+
}
1164+
11161165
} // namespace tpu
11171166
} // namespace mlir
11181167

jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc

Lines changed: 77 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1035,12 +1035,19 @@ LogicalResult trunc_op_rule_impl(RewriteContext &ctx, OpTy op,
10351035
output_vregs.Each([&](absl::Span<const int64_t> idxs, Value *v) {
10361036
SmallVector<Value> parts;
10371037
SmallVector<int64_t> idxs_local(toArrayRef(idxs));
1038-
idxs_local.back() *= packing;
1039-
for (int64_t i = 0; i < packing; ++i) {
1040-
parts.push_back(input_vregs(idxs_local));
1041-
// Pack any data lying around if OOB
1042-
if (idxs_local.back() < input_vregs.dimensions().back() - 1) {
1043-
++idxs_local.back();
1038+
if (!layout_out.offsets()[1].has_value()) {
1039+
idxs_local.back() = 0;
1040+
// Make sure we set all parts of the output vreg to make it replicated
1041+
parts.append(packing, input_vregs(idxs_local));
1042+
} else {
1043+
idxs_local.back() *= packing;
1044+
for (int64_t i = 0; i < packing; ++i) {
1045+
if (idxs_local.back() < input_vregs.dimensions().back()) {
1046+
parts.push_back(input_vregs(idxs_local));
1047+
++idxs_local.back();
1048+
} else {
1049+
parts.push_back(nullptr);
1050+
}
10441051
}
10451052
}
10461053
*v = builder.create<PackSubelementsOp>(res_vreg_ty, parts,
@@ -1053,16 +1060,19 @@ LogicalResult trunc_op_rule_impl(RewriteContext &ctx, OpTy op,
10531060
output_vregs.Each([&](absl::Span<const int64_t> idxs, Value *v) {
10541061
CHECK_GE(idxs.size(), 2);
10551062
SmallVector<int64_t> idxs_local(toArrayRef(idxs));
1056-
idxs_local[idxs.size() - 2] *= packing;
1057-
parts.push_back(input_vregs(idxs_local));
1058-
idxs_local[idxs.size() - 2]++;
1059-
while (parts.size() < packing) {
1060-
if (*(idxs_local.end() - 2) < *(input_vregs.dimensions().end() - 2)) {
1061-
parts.push_back(input_vregs(idxs_local));
1062-
idxs_local[idxs.size() - 2]++;
1063-
} else {
1064-
// Once we run out of tiles, we can pick any one we like.
1065-
parts.push_back(parts.back());
1063+
if (!layout_out.offsets()[0].has_value()) {
1064+
*(idxs_local.end() - 2) = 0;
1065+
// Make sure we set all parts of the output vreg to make it replicated
1066+
parts.append(packing, input_vregs(idxs_local));
1067+
} else {
1068+
*(idxs_local.end() - 2) *= packing;
1069+
for (int64_t i = 0; i < packing; ++i) {
1070+
if (*(idxs_local.end() - 2) < *(input_vregs.dimensions().end() - 2)) {
1071+
parts.push_back(input_vregs(idxs_local));
1072+
++*(idxs_local.end() - 2);
1073+
} else {
1074+
parts.push_back(nullptr);
1075+
}
10661076
}
10671077
}
10681078
*v = builder.create<PackSubelementsOp>(res_vreg_ty, parts,
@@ -6253,6 +6263,11 @@ FailureOr<std::pair<VectorLayout, xla::Array<Value>>> changeTiling(
62536263
ctx.target_shape[1]}) {
62546264
// Note: for int4, retiling with scratch is always faster.
62556265
if (bitwidth != 4 || !has_enough_scratch) {
6266+
// Note: The code below does not work when src is replicated and dst is
6267+
// not, since it relies on the src vreg array shape to know how many tiles
6268+
// to pack in dst, and vreg array shapes with materialized offsets are
6269+
// unfortunately not equal to vreg array shapes with replicated offsets.
6270+
CHECK(dst.offsets() == src_offsets);
62566271
xla::Array<Value> retiled(dst_tiles_shape);
62576272
VectorType vreg_x32 =
62586273
vty.getElementType().isSignlessInteger()
@@ -6263,19 +6278,29 @@ FailureOr<std::pair<VectorLayout, xla::Array<Value>>> changeTiling(
62636278
SmallVector<Value, 8> parts;
62646279
parts.reserve(packing);
62656280
SmallVector<int64_t> src_idx(idx.begin(), idx.end());
6266-
src_idx[src_idx.size() - 2] *= packing;
6267-
src_idx[src_idx.size() - 1] /= packing;
6268-
for (int i = 0; i < packing; ++i) {
6269-
parts.push_back(builder.create<tpu::UnpackSubelementsOp>(
6270-
loc, vreg_x32, vregs(src_idx), vreg_part,
6271-
tpu::PackFormat::kCompressed));
6272-
if (src_idx[src_idx.size() - 2] <
6273-
vregs.dim(vregs.num_dimensions() - 2) - 1) {
6274-
++src_idx[src_idx.size() - 2];
6281+
*(src_idx.end() - 1) /= packing;
6282+
if (!dst.offsets()[0].has_value()) {
6283+
*(src_idx.end() - 2) = 0;
6284+
// Make sure we set all parts of the output vreg to make it replicated
6285+
parts.append(packing, builder.create<tpu::UnpackSubelementsOp>(
6286+
loc, vreg_x32, vregs(src_idx), vreg_part,
6287+
tpu::PackFormat::kCompressed));
6288+
} else {
6289+
*(src_idx.end() - 2) *= packing;
6290+
for (int i = 0; i < packing; ++i) {
6291+
if (*(src_idx.end() - 2) < *(vregs.dimensions().end() - 2)) {
6292+
parts.push_back(builder.create<tpu::UnpackSubelementsOp>(
6293+
loc, vreg_x32, vregs(src_idx), vreg_part,
6294+
tpu::PackFormat::kCompressed));
6295+
++*(src_idx.end() - 2);
6296+
} else {
6297+
parts.push_back(nullptr);
6298+
}
62756299
}
62766300
}
62776301
*tile = builder.create<tpu::PackSubelementsOp>(
6278-
loc, vregs.begin()->getType(), parts, tpu::PackFormat::kCompressed);
6302+
loc, cast<VectorType>(vregs.begin()->getType()), parts,
6303+
tpu::PackFormat::kCompressed);
62796304
});
62806305
return std::pair(dst, std::move(retiled));
62816306
}
@@ -6334,6 +6359,12 @@ FailureOr<std::pair<VectorLayout, xla::Array<Value>>> changeTiling(
63346359
// [(a b) (A B) (c d) (C D) ...]. That is, traverse down each column before
63356360
// moving to the next one. This is exactly an interleaving of the sublanes
63366361
// of the vreg parts.
6362+
6363+
// Note: The code below does not work when src is replicated and dst is
6364+
// not, since it relies on the src vreg array shape to know how many tiles
6365+
// to pack in dst, and vreg array shapes with materialized offsets are
6366+
// unfortunately not equal to vreg array shapes with replicated offsets.
6367+
CHECK(dst.offsets() == src.offsets());
63376368
xla::Array<Value> retiled(dst_tiles_shape);
63386369
const VectorType vreg_x32 =
63396370
vty.getElementType().isSignlessInteger()
@@ -6343,20 +6374,30 @@ FailureOr<std::pair<VectorLayout, xla::Array<Value>>> changeTiling(
63436374
SmallVector<Value> parts;
63446375
parts.reserve(packing);
63456376
SmallVector<int64_t> src_idx(toArrayRef(idx));
6346-
*(src_idx.end() - 2) *= packing;
63476377
const int64_t vreg_part = *(src_idx.end() - 1) % packing;
63486378
*(src_idx.end() - 1) /= packing;
6349-
for (int i = 0; i < packing; ++i) {
6350-
parts.push_back(builder.create<tpu::UnpackSubelementsOp>(
6351-
loc, vreg_x32, vregs(src_idx), vreg_part,
6352-
tpu::PackFormat::kCompressed));
6353-
if (*(src_idx.end() - 2) < *(vregs.dimensions().end() - 2) - 1) {
6354-
++*(src_idx.end() - 2);
6355-
} // The rest is padding, so just pick any of the input parts (but not
6356-
// an arbitrary vreg so we don't add an extra dependency).
6379+
if (!dst.offsets()[0].has_value()) {
6380+
*(src_idx.end() - 2) = 0;
6381+
// Make sure we set all parts of the output vreg to make it replicated
6382+
parts.append(packing, builder.create<tpu::UnpackSubelementsOp>(
6383+
loc, vreg_x32, vregs(src_idx), vreg_part,
6384+
tpu::PackFormat::kCompressed));
6385+
} else {
6386+
*(src_idx.end() - 2) *= packing;
6387+
for (int i = 0; i < packing; ++i) {
6388+
if (*(src_idx.end() - 2) < *(vregs.dimensions().end() - 2)) {
6389+
parts.push_back(builder.create<tpu::UnpackSubelementsOp>(
6390+
loc, vreg_x32, vregs(src_idx), vreg_part,
6391+
tpu::PackFormat::kCompressed));
6392+
++*(src_idx.end() - 2);
6393+
} else {
6394+
parts.push_back(nullptr);
6395+
}
6396+
}
63576397
}
63586398
*tile = builder.create<tpu::PackSubelementsOp>(
6359-
loc, vregs.begin()->getType(), parts, tpu::PackFormat::kInterleaved);
6399+
loc, cast<VectorType>(vregs.begin()->getType()), parts,
6400+
tpu::PackFormat::kInterleaved);
63606401
});
63616402
return std::pair(dst, std::move(retiled));
63626403
}

0 commit comments

Comments
 (0)