@@ -289,7 +289,7 @@ def _dout(
289289 elif operator in ['D1N2' ]:
290290 kmat = [f'tMM{ ii } ' for ii in range (nnd )]
291291 elif operator in ['D2N2' ]:
292- lcomb = [] if nd == '1d' else [(0 ,1 )]
292+ lcomb = [] if nd == '1d' else [(0 , 1 )]
293293 kmat = (
294294 [f'tMM{ ii } { ii } ' for ii in range (nnd )]
295295 + [f'tMM{ ii } { jj } ' for ii , jj in lcomb ]
@@ -434,7 +434,7 @@ def _units(u0=None, operator=None, geometry=None):
434434 if geometry == 'linear' :
435435 units = u0
436436 else :
437- units = u0 ** 2
437+ units = u0 ** 2 / asunits . Unit ( 'rad' )
438438
439439 elif operator == 'D0N2' :
440440 if geometry == 'linear' :
@@ -533,20 +533,12 @@ def apply_operator(
533533
534534 if operator == 'D0N1' :
535535 ind = [- 1 ]
536- if cropbs is None :
537- for k0 in key :
538- ddata [k0 ]['data' ][...] = np .tensordot (
539- integ_op ['M' ]['data' ],
540- coll .ddata [k0 ]['data' ],
541- (ind , daxis [k0 ]['axis' ]),
542- )
543- else :
544- for k0 in key :
545- ddata [k0 ]['data' ][...] = np .tensordot (
546- integ_op ['M' ]['data' ],
547- coll .ddata [k0 ]['data' ][daxis [k0 ]['slice' ]],
548- (ind , daxis [k0 ]['axis' ]),
549- )
536+ for k0 in key :
537+ ddata [k0 ]['data' ][...] = np .tensordot (
538+ integ_op ['M' ]['data' ],
539+ coll .ddata [k0 ]['data' ][daxis [k0 ]['slice' ]],
540+ (ind , daxis [k0 ]['axis' ][0 ]),
541+ )
550542 else :
551543 raise NotImplementedError ()
552544
@@ -759,7 +751,7 @@ def _apply_operator_prepare(
759751 else :
760752 raise NotImplementedError ()
761753
762- # populate
754+ # populate
763755 ddata [k0 ] = {
764756 'data' : np .full (shape , np .nan ),
765757 'ref' : ref ,
@@ -768,18 +760,18 @@ def _apply_operator_prepare(
768760
769761 # slicing
770762 if cropbs is None :
771- sli = None
772- else :
773- sli = tuple ([
774- cropbs if ii == axis [0 ]
775- else slice (None )
776- for ii in range (len (shape0 ))
777- if ii not in axisf [1 :]
778- ])
763+ shcrop = tuple ([ shape0 [ ii ] for ii in axis ])
764+ cropbs = np . ones ( shcrop , dtype = bool )
765+ sli = tuple ([
766+ cropbs if ii == axis [0 ]
767+ else slice (None )
768+ for ii in range (len (shape0 ))
769+ if ii not in axisf [1 :]
770+ ])
779771
780772 daxis [k0 ] = {
781773 'slice' : sli ,
782774 'axis' : axis ,
783775 }
784776
785- return ddata , daxis , cropbs
777+ return ddata , daxis , cropbs
0 commit comments