File tree Expand file tree Collapse file tree 3 files changed +27
-0
lines changed
jaxlib/mosaic/dialect/tpu Expand file tree Collapse file tree 3 files changed +27
-0
lines changed Original file line number Diff line number Diff line change @@ -93,6 +93,25 @@ LogicalResult MemRefSliceOp::verify() {
9393 auto target_type = getType ();
9494 auto target_layout = target_type.getLayout ();
9595 auto target_memory_space = target_type.getMemorySpace ();
96+ auto indices = getBaseIdx ();
97+ auto slice_shape = getResult ().getType ().getShape ();
98+ if (!source_type.hasStaticShape ()) {
99+ return emitOpError (
100+ " Only slicing of memrefs with static shapes is supported." );
101+ }
102+ auto source_shape = source_type.getShape ();
103+ bool is_semaphore =
104+ HasMemorySpace (source_type, tpu::MemorySpace::kSemaphoreMem );
105+ if (is_semaphore &&
106+ !isa<SemaphoreType, DMASemaphoreType>(source_type.getElementType ())) {
107+ return emitOpError (
108+ " References to semaphore memory space must have a semaphore element "
109+ " type." );
110+ }
111+ if (indices.size () != slice_shape.size () ||
112+ indices.size () != source_shape.size ()) {
113+ return emitOpError (" Indices and slice shapes must match." );
114+ }
96115 // TODO(apaszke): Check that the result has a smaller shape.
97116 // TODO(apaszke): Check that strides are equivalent.
98117 // Source and target attributes may be different before propagation is done by
Original file line number Diff line number Diff line change @@ -141,4 +141,10 @@ bool canReinterpretToUntiledMemref(TypedValue<MemRefType> tiled_memref,
141141 return *(tiled_layout.getTileStrides ().end () - 1 ) == 1 &&
142142 *(tiled_layout.getTileStrides ().end () - 2 ) == 1 ;
143143}
144+
145+ bool HasMemorySpace (MemRefType ty, tpu::MemorySpace space) {
146+ auto memory_space =
147+ dyn_cast_or_null<tpu::MemorySpaceAttr>(ty.getMemorySpace ());
148+ return memory_space && memory_space.getValue () == space;
149+ }
144150} // namespace mlir::tpu
Original file line number Diff line number Diff line change @@ -115,6 +115,8 @@ bool canReinterpretToUntiledMemref(TypedValue<MemRefType> tiled_memref,
115115 const std::array<int64_t , 2 > &target_shape,
116116 bool allow_minormost_padding = false );
117117
118+ // Determines whether the given MemRefType has the given memory space.
119+ bool HasMemorySpace (MemRefType ty, tpu::MemorySpace space);
118120} // namespace mlir::tpu
119121
120122#endif // THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_UTIL_H_
You can’t perform that action at this time.
0 commit comments