@@ -872,9 +872,30 @@ def _ceildiv(p: int, q: int) -> int:
872872# ---
873873
874874
875+ def _parse_stride (st ):
876+ if isinstance (st , int ):
877+ stride = st
878+ start = [0 ]
879+ elif isinstance (st , tuple ):
880+ stride , start = st [1 ], list (st [0 ])
881+
882+ return stride , start
883+
875884def _parse_rma (target , source , size = None , tst = 1 , sst = 1 ):
876- tdata , tlen , ttype = _getbuffer (target , readonly = False )
877- sdata , slen , stype = _getbuffer (source , readonly = True )
885+ if isinstance (tst , tuple ): assert target .ndim == len (tst [0 ])
886+ if isinstance (sst , tuple ): assert source .ndim == len (sst [0 ])
887+ tst , tstart = _parse_stride (tst )
888+ sst , sstart = _parse_stride (sst )
889+
890+ if tstart != [0 ]:
891+ tdata , tlen , ttype = _getbuffer (target [* tstart [:- 1 ],tstart [- 1 ]:], readonly = False )
892+ else :
893+ tdata , tlen , ttype = _getbuffer (target , readonly = False )
894+
895+ if sstart != [0 ]:
896+ sdata , slen , stype = _getbuffer (source [* sstart [:- 1 ],sstart [- 1 ]:], readonly = True )
897+ else :
898+ sdata , slen , stype = _getbuffer (source , readonly = True )
878899
879900 assert ttype == stype
880901 ctype = ttype
@@ -884,8 +905,7 @@ def _parse_rma(target, source, size=None, tst=1, sst=1):
884905 if size is None :
885906 size = min (tsize , ssize )
886907 else :
887- assert size <= tsize
888- assert size <= ssize
908+ assert size >= 0
889909
890910 return (ctype , tdata , sdata , size )
891911
@@ -901,6 +921,8 @@ def _shmem_rma(ctx, name, target, source, size, pe):
901921
902922def _shmem_irma (ctx , name , target , source , tst , sst , size , pe ):
903923 ctype , target , source , size = _parse_rma (target , source , size , tst , sst )
924+ tst , _ = _parse_stride (tst )
925+ sst , _ = _parse_stride (sst )
904926 return _shmem (ctx , ctype , f'i{ name } ' )(target , source , tst , sst , size , pe )
905927
906928
0 commit comments