Skip to content

Commit e72b449

Browse files
naummoGoogle-ML-Automation
authored andcommitted
Reverts c04aec9
PiperOrigin-RevId: 698654038
1 parent 6568713 commit e72b449

File tree

3 files changed

+24
-61
lines changed

3 files changed

+24
-61
lines changed

jaxlib/mosaic/dialect/tpu/tpu.td

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -654,15 +654,14 @@ def TPU_SemaphoreSignalOp : TPU_Op<"sem_signal", [AttrSizedOperandSegments]> {
654654
I32:$amount,
655655
Optional<I32>:$device_id, // For remote DMAs
656656
Optional<I32>:$core_id, // For megacore
657-
Optional<I32>:$subcore_id, // For the SC vector subcore
658657
OptionalAttr<TPU_CoreTypeEnum>:$core_type
659658
);
660659
let assemblyFormat = [{
661-
$semaphore `,` $amount (`device_id` $device_id^)? (`core_id` $core_id^)? (`subcore_id` $subcore_id^)? (`core_type` $core_type^)? attr-dict `:` type($semaphore)
660+
$semaphore `,` $amount (`device_id` $device_id^)? (`core_id` $core_id^)? (`core_type` $core_type^)? attr-dict `:` type($semaphore)
662661
}];
663662
let hasVerifier = 1;
664663
let builders = [
665-
// A backward-compatible builder that sets `subcore_id` and `core_type` to nullptr.
664+
// A backward-compatible builder that sets `core_type` to nullptr.
666665
OpBuilder<(ins "Value":$semaphore, "Value":$amount,
667666
"Value":$device_id, "Value":$core_id)>,
668667
];

jaxlib/mosaic/dialect/tpu/tpu_ops.cc

Lines changed: 9 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -844,7 +844,7 @@ void SemaphoreSignalOp::build(OpBuilder &builder, OperationState &state,
844844
Value semaphore, Value amount, Value device_id,
845845
Value core_id) {
846846
build(builder, state, semaphore, amount, device_id, core_id,
847-
/*subcore_id=*/nullptr, /*core_type=*/nullptr);
847+
/*core_type=*/nullptr);
848848
}
849849

850850
LogicalResult SemaphoreSignalOp::verify() {
@@ -861,39 +861,21 @@ LogicalResult SemaphoreSignalOp::verify() {
861861
CoreType issuing_core_type = issuing_core_type_maybe->value_or(CoreType::kTc);
862862
CoreType target_core_type = getCoreType().value_or(issuing_core_type);
863863

864-
if (getCoreId() == nullptr && getDeviceId() == nullptr &&
865-
getSubcoreId() == nullptr) {
864+
if (getCoreId() == nullptr && getDeviceId() == nullptr) {
866865
if (target_core_type != issuing_core_type) {
867-
return emitOpError(absl::StrFormat(
868-
"Target core type (%s) must match source core type "
869-
"(%s) when device_id, core_id and subcore_id are not specified",
870-
stringifyCoreType(target_core_type),
871-
stringifyCoreType(issuing_core_type)));
866+
return emitOpError(
867+
absl::StrFormat("Target core type (%s) must match source core type "
868+
"(%s) when device_id and core_id are not specified",
869+
stringifyCoreType(target_core_type),
870+
stringifyCoreType(issuing_core_type)));
872871
}
873872
}
874-
if (target_core_type == CoreType::kScVectorSubcore &&
875-
issuing_core_type != CoreType::kScVectorSubcore &&
876-
getSubcoreId() == nullptr) {
877-
return emitOpError(
878-
"Subcore ID must be specified for the SC vector subcore");
879-
}
880-
if (target_core_type != CoreType::kScVectorSubcore &&
881-
getSubcoreId() != nullptr) {
882-
return emitOpError(
883-
"Subcore ID must be specified only for the SC vector subcore");
884-
}
885873
if ((issuing_core_type == CoreType::kTc &&
886-
(target_core_type == CoreType::kScScalarSubcore ||
887-
target_core_type == CoreType::kScVectorSubcore)) ||
888-
((issuing_core_type == CoreType::kScScalarSubcore ||
889-
issuing_core_type == CoreType::kScVectorSubcore) &&
874+
target_core_type == CoreType::kScScalarSubcore) ||
875+
(issuing_core_type == CoreType::kScScalarSubcore &&
890876
target_core_type == CoreType::kTc)) {
891877
return emitOpError("Signalling between TC and SC is not implemented");
892878
}
893-
if (target_core_type == CoreType::kScVectorSubcore &&
894-
(getCoreId() != nullptr || getDeviceId() != nullptr)) {
895-
return emitOpError("Signalling remote SC vector subcores is not supported");
896-
}
897879
return success();
898880
}
899881

jaxlib/mosaic/dialect/tpu/transforms/serde.cc

Lines changed: 13 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -15,21 +15,19 @@ limitations under the License.
1515

1616
// We need to keep some extra headers for the code in tpu_passes.h.inc.
1717

18-
#include <cstdint>
1918
#include <memory> // IWYU pragma: keep
2019
#include <optional>
2120
#include <string>
2221
#include <string_view>
2322

24-
#include "mlir/Dialect/Vector/IR/VectorOps.h"
2523
#include "mlir/IR/BuiltinAttributes.h"
2624
#include "mlir/IR/BuiltinOps.h"
25+
#include "mlir/Dialect/Vector/IR/VectorOps.h"
2726
#include "mlir/IR/OperationSupport.h"
2827
#include "mlir/IR/Value.h"
2928
#include "mlir/IR/Visitors.h"
3029
#include "mlir/Pass/Pass.h" // IWYU pragma: keep
3130
#include "mlir/Support/LLVM.h"
32-
#include "absl/strings/str_format.h"
3331
#include "mlir/include/mlir/IR/BuiltinAttributes.h"
3432
#include "mlir/include/mlir/IR/OpDefinition.h"
3533
#include "mlir/include/mlir/IR/OperationSupport.h"
@@ -45,7 +43,7 @@ namespace {
4543

4644
constexpr std::string_view kMangledDialect = "stable_mosaic.";
4745
constexpr StringRef kVersionAttrName = "stable_mosaic.version";
48-
constexpr int kVersion = 4;
46+
constexpr int kVersion = 3;
4947

5048
StringRef mangle(StringRef name, std::string* storage) {
5149
storage->clear();
@@ -88,37 +86,21 @@ LogicalResult enqueue_dma_rule(Operation* op, int version) {
8886

8987
LogicalResult semaphore_signal_rule(Operation* op, int version) {
9088
// Added AttrSizedOperandSegments and core_id in version 2.
91-
// Added subcore_id in version 4.
9289
if (version < 2) {
9390
if (op->getNumOperands() == 2) { // Local signal.
94-
op->setAttr(
95-
OpTrait::AttrSizedOperandSegments<
96-
EnqueueDMAOp>::getOperandSegmentSizeAttr(),
97-
mlir::DenseI32ArrayAttr::get(op->getContext(), {1, 1, 0, 0, 0}));
91+
op->setAttr(OpTrait::AttrSizedOperandSegments<
92+
EnqueueDMAOp>::getOperandSegmentSizeAttr(),
93+
mlir::DenseI32ArrayAttr::get(op->getContext(), {1, 1, 0, 0}));
9894
} else if (op->getNumOperands() == 3) { // Remote signal.
99-
op->setAttr(
100-
OpTrait::AttrSizedOperandSegments<
101-
EnqueueDMAOp>::getOperandSegmentSizeAttr(),
102-
mlir::DenseI32ArrayAttr::get(op->getContext(), {1, 1, 1, 0, 0}));
103-
}
104-
return op->emitError("Unexpected operand count in tpu.semaphore_signal");
105-
} else if (version < 4) {
106-
ArrayRef<int32_t> operand_segment_sizes =
107-
op->getAttrOfType<DenseI32ArrayAttr>(
108-
OpTrait::AttrSizedOperandSegments<
109-
SemaphoreSignalOp>::getOperandSegmentSizeAttr());
110-
if (operand_segment_sizes.size() != 4) {
111-
return op->emitError(absl::StrFormat(
112-
"Expected operand count to be 4 in tpu.semaphore_signal. Got %d",
113-
operand_segment_sizes.size()));
95+
// Hardcoding that one optional value is device_id, not core_id. This
96+
// could misinterpret sem_signals where core_id is specified, but
97+
// device_id isn't.
98+
op->setAttr(OpTrait::AttrSizedOperandSegments<
99+
EnqueueDMAOp>::getOperandSegmentSizeAttr(),
100+
mlir::DenseI32ArrayAttr::get(op->getContext(), {1, 1, 1, 0}));
101+
} else {
102+
return op->emitError("Unexpected operand count in tpu.semaphore_signal");
114103
}
115-
SmallVector<int32_t, 5> new_operand_segment_sizes(
116-
operand_segment_sizes.begin(), operand_segment_sizes.end());
117-
new_operand_segment_sizes.push_back(0);
118-
op->setAttr(OpTrait::AttrSizedOperandSegments<
119-
EnqueueDMAOp>::getOperandSegmentSizeAttr(),
120-
mlir::DenseI32ArrayAttr::get(op->getContext(),
121-
new_operand_segment_sizes));
122104
}
123105
return success();
124106
}

0 commit comments

Comments
 (0)