@@ -648,18 +648,30 @@ getElemTypeAndIdxStructure (LeafType ctxs baseTy) = case ctxs of
648648 Empty -> (UnboxedValue baseTy, Nothing )
649649 Nest b rest -> case b of
650650 TabCtx _ -> error " leading idxs should have been stripped off already"
651- DepPairCtx depBinder ->
652- case getIExprInterpretation (LeafType rest baseTy) of
651+ DepPairCtx depBinder -> case splitLeadingDepPairs rest of
652+ PairB depBinders rest' -> case getIExprInterpretation (LeafType rest' baseTy) of
653653 RawValue bt -> (UnboxedValue bt, Nothing )
654654 BufferPtr (BufferType ixs eltTy) -> do
655- let ixs' = case depBinder of
656- LeftB _ -> Nothing
657- RightB UnitB -> Just ixs
655+ let ixs' = case allNothingBs ( Nest depBinder depBinders) of
656+ Just UnitB -> Just ixs
657+ Nothing -> Nothing
658658 (BoxedBuffer eltTy, ixs')
659659 RefCtx -> (,Nothing ) $ UnboxedValue $ hostPtrTy $ elemTypeToBaseType eltTy
660660 where BufferType _ eltTy = getRefBufferType (LeafType rest baseTy)
661661 where hostPtrTy ty = PtrType (CPU , ty)
662662
663+ allNothingBs :: Nest (MaybeB b ) n l -> Maybe (UnitB n l )
664+ allNothingBs Empty = Just UnitB
665+ allNothingBs (Nest (LeftB _) _) = Nothing
666+ allNothingBs (Nest (RightB UnitB ) rest) = allNothingBs rest
667+
668+ splitLeadingDepPairs :: TypeCtx SimpIR n l -> PairB (Nest (MaybeB SBinder )) (TypeCtx SimpIR ) n l
669+ splitLeadingDepPairs = \ case
670+ Empty -> PairB Empty Empty
671+ Nest (DepPairCtx b) rest -> case splitLeadingDepPairs rest of
672+ PairB bs rest' -> PairB (Nest b bs) rest'
673+ ctxs -> PairB Empty ctxs
674+
663675tryGetBoxIdxStructure :: LeafType n -> Maybe (IndexStructure SimpIR n )
664676tryGetBoxIdxStructure leafTy = snd $ getElemTypeAndIdxStructure leafTy
665677
@@ -727,14 +739,14 @@ valueToTree (RepVal tyTop valTop) = do
727739 RefTy _ t -> go (RNest ctx RefCtx ) t val
728740 DepPairTy (DepPairType _ (b:> t1) (t2)) -> case val of
729741 Branch [v1, v2] -> do
730- case ctx of
731- REmpty -> do
742+ case allDepPairCtxs (unRNest ctx) of
743+ Just UnitB -> do
732744 tree1 <- rec t1 v1
733745 x <- repValAtom $ RepVal t1 v1
734746 t2' <- applySubst (b@> SubstVal x) t2
735747 tree2 <- go (RNest ctx (DepPairCtx NothingB )) t2' v2
736748 return $ Branch [tree1, tree2]
737- _ -> do
749+ Nothing -> do
738750 tree1 <- rec t1 v1
739751 tree2 <- go (RNest ctx (DepPairCtx (JustB (b:> t1)))) t2 v2
740752 return $ Branch [tree1, tree2]
@@ -752,6 +764,11 @@ valueToTree (RepVal tyTop valTop) = do
752764 where rec = go ctx
753765{-# INLINE valueToTree #-}
754766
767+ allDepPairCtxs :: TypeCtx SimpIR n l -> Maybe (UnitB n l )
768+ allDepPairCtxs ctx = case splitLeadingDepPairs ctx of
769+ PairB bs Empty -> allNothingBs bs
770+ _ -> Nothing
771+
755772storeLeaf :: Emits n => LeafType n -> IExpr n -> IExpr n -> SubstImpM i n ()
756773storeLeaf leafTy dest src = case getRefBufferType leafTy of
757774 BufferType Singleton (UnboxedValue _) -> store dest src
0 commit comments