Skip to content

Commit 3990e05

Browse files
naummoGoogle-ML-Automation
authored andcommitted
[Mosaic] Add extra memref_slice verification and a memory space check helper
PiperOrigin-RevId: 702883469
1 parent 5ade371 commit 3990e05

File tree

3 files changed

+27
-0
lines changed

3 files changed

+27
-0
lines changed

jaxlib/mosaic/dialect/tpu/tpu_ops.cc

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,25 @@ LogicalResult MemRefSliceOp::verify() {
9393
auto target_type = getType();
9494
auto target_layout = target_type.getLayout();
9595
auto target_memory_space = target_type.getMemorySpace();
96+
auto indices = getBaseIdx();
97+
auto slice_shape = getResult().getType().getShape();
98+
if (!source_type.hasStaticShape()) {
99+
return emitOpError(
100+
"Only slicing of memrefs with static shapes is supported.");
101+
}
102+
auto source_shape = source_type.getShape();
103+
bool is_semaphore =
104+
HasMemorySpace(source_type, tpu::MemorySpace::kSemaphoreMem);
105+
if (is_semaphore &&
106+
!isa<SemaphoreType, DMASemaphoreType>(source_type.getElementType())) {
107+
return emitOpError(
108+
"References to semaphore memory space must have a semaphore element "
109+
"type.");
110+
}
111+
if (indices.size() != slice_shape.size() ||
112+
indices.size() != source_shape.size()) {
113+
return emitOpError("Indices and slice shapes must match.");
114+
}
96115
// TODO(apaszke): Check that the result has a smaller shape.
97116
// TODO(apaszke): Check that strides are equivalent.
98117
// Source and target attributes may be different before propagation is done by

jaxlib/mosaic/dialect/tpu/util.cc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,4 +141,10 @@ bool canReinterpretToUntiledMemref(TypedValue<MemRefType> tiled_memref,
141141
return *(tiled_layout.getTileStrides().end() - 1) == 1 &&
142142
*(tiled_layout.getTileStrides().end() - 2) == 1;
143143
}
144+
145+
bool HasMemorySpace(MemRefType ty, tpu::MemorySpace space) {
146+
auto memory_space =
147+
dyn_cast_or_null<tpu::MemorySpaceAttr>(ty.getMemorySpace());
148+
return memory_space && memory_space.getValue() == space;
149+
}
144150
} // namespace mlir::tpu

jaxlib/mosaic/dialect/tpu/util.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,8 @@ bool canReinterpretToUntiledMemref(TypedValue<MemRefType> tiled_memref,
115115
const std::array<int64_t, 2> &target_shape,
116116
bool allow_minormost_padding = false);
117117

118+
// Determines whether the given MemRefType has the given memory space.
119+
bool HasMemorySpace(MemRefType ty, tpu::MemorySpace space);
118120
} // namespace mlir::tpu
119121

120122
#endif // THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_UTIL_H_

0 commit comments

Comments
 (0)