Skip to content

Commit 978ef7a

Browse files
authored
Merge pull request #1277 from google-research/coalesce-dep-pairs
Coalesce dependent pair contexts in Imp representation.
2 parents 416b07f + 3d8435a commit 978ef7a

File tree

3 files changed

+40
-9
lines changed

3 files changed

+40
-9
lines changed

lib/prelude.dx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2605,7 +2605,7 @@ def (**)(
26052605
y: m=>n=>Float
26062606
) -> l=>n=>Float given (l|Ix, m|Ix, n|Ix) =
26072607
-- TODO(https://github.com/google-research/dex-lang/issues/1212) Replace with tiled_matmul.
2608-
naive_matmul(x, y)
2608+
tiled_matmul(x, y)
26092609

26102610
def (**.)(mat: n=>m=>Float, v: m=>Float) -> (n=>Float) given (n|Ix, m|Ix) =
26112611
for i. vdot(mat[i], v)

src/lib/Imp.hs

Lines changed: 25 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -648,18 +648,30 @@ getElemTypeAndIdxStructure (LeafType ctxs baseTy) = case ctxs of
648648
Empty -> (UnboxedValue baseTy, Nothing)
649649
Nest b rest -> case b of
650650
TabCtx _ -> error "leading idxs should have been stripped off already"
651-
DepPairCtx depBinder ->
652-
case getIExprInterpretation (LeafType rest baseTy) of
651+
DepPairCtx depBinder -> case splitLeadingDepPairs rest of
652+
PairB depBinders rest' -> case getIExprInterpretation (LeafType rest' baseTy) of
653653
RawValue bt -> (UnboxedValue bt, Nothing)
654654
BufferPtr (BufferType ixs eltTy) -> do
655-
let ixs' = case depBinder of
656-
LeftB _ -> Nothing
657-
RightB UnitB -> Just ixs
655+
let ixs' = case allNothingBs (Nest depBinder depBinders) of
656+
Just UnitB -> Just ixs
657+
Nothing -> Nothing
658658
(BoxedBuffer eltTy, ixs')
659659
RefCtx -> (,Nothing) $ UnboxedValue $ hostPtrTy $ elemTypeToBaseType eltTy
660660
where BufferType _ eltTy = getRefBufferType (LeafType rest baseTy)
661661
where hostPtrTy ty = PtrType (CPU, ty)
662662

663+
allNothingBs :: Nest (MaybeB b) n l -> Maybe (UnitB n l)
664+
allNothingBs Empty = Just UnitB
665+
allNothingBs (Nest (LeftB _) _) = Nothing
666+
allNothingBs (Nest (RightB UnitB) rest) = allNothingBs rest
667+
668+
splitLeadingDepPairs :: TypeCtx SimpIR n l -> PairB (Nest (MaybeB SBinder)) (TypeCtx SimpIR) n l
669+
splitLeadingDepPairs = \case
670+
Empty -> PairB Empty Empty
671+
Nest (DepPairCtx b) rest -> case splitLeadingDepPairs rest of
672+
PairB bs rest' -> PairB (Nest b bs) rest'
673+
ctxs -> PairB Empty ctxs
674+
663675
tryGetBoxIdxStructure :: LeafType n -> Maybe (IndexStructure SimpIR n)
664676
tryGetBoxIdxStructure leafTy = snd $ getElemTypeAndIdxStructure leafTy
665677

@@ -727,14 +739,14 @@ valueToTree (RepVal tyTop valTop) = do
727739
RefTy _ t -> go (RNest ctx RefCtx) t val
728740
DepPairTy (DepPairType _ (b:>t1) (t2)) -> case val of
729741
Branch [v1, v2] -> do
730-
case ctx of
731-
REmpty -> do
742+
case allDepPairCtxs (unRNest ctx) of
743+
Just UnitB -> do
732744
tree1 <- rec t1 v1
733745
x <- repValAtom $ RepVal t1 v1
734746
t2' <- applySubst (b@>SubstVal x) t2
735747
tree2 <- go (RNest ctx (DepPairCtx NothingB )) t2' v2
736748
return $ Branch [tree1, tree2]
737-
_ -> do
749+
Nothing -> do
738750
tree1 <- rec t1 v1
739751
tree2 <- go (RNest ctx (DepPairCtx (JustB (b:>t1)))) t2 v2
740752
return $ Branch [tree1, tree2]
@@ -752,6 +764,11 @@ valueToTree (RepVal tyTop valTop) = do
752764
where rec = go ctx
753765
{-# INLINE valueToTree #-}
754766

767+
allDepPairCtxs :: TypeCtx SimpIR n l -> Maybe (UnitB n l)
768+
allDepPairCtxs ctx = case splitLeadingDepPairs ctx of
769+
PairB bs Empty -> allNothingBs bs
770+
_ -> Nothing
771+
755772
storeLeaf :: Emits n => LeafType n -> IExpr n -> IExpr n -> SubstImpM i n ()
756773
storeLeaf leafTy dest src = case getRefBufferType leafTy of
757774
BufferType Singleton (UnboxedValue _) -> store dest src

tests/eval-tests.dx

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1118,6 +1118,20 @@ case frob of
11181118
> 4.
11191119

11201120

1121+
-- regression tests for #1212
1122+
data Rectangle(a) = AsRectangle(n:Nat, m:Nat, elts:(Fin n => Fin m => a))
1123+
data Brick(a) = AsBrick(n:Nat, m:Nat, l:Nat, elts:(Fin n => Fin m => Fin l => a))
11211124

1125+
def mk_rect(n:Nat) -> Rectangle Nat =
1126+
AsRectangle _ _ $ for i:(Fin n) j:(Fin n). ordinal i
11221127

1128+
def mk_brick(n:Nat) -> Brick Nat =
1129+
AsBrick _ _ _ $ for i:(Fin n) j:(Fin n) k:(Fin n). ordinal i
11231130

1131+
rect = mk_rect(3)
1132+
rect
1133+
> (AsRectangle 3 3 [[0, 0, 0], [1, 1, 1], [2, 2, 2]])
1134+
1135+
brick = mk_brick(3)
1136+
brick
1137+
> (AsBrick 3 3 3 [[[0, 0, 0], [0, 0, 0], [0, 0, 0]], [[1, 1, 1], [1, 1, 1], [1, 1, 1]], [[2, 2, 2], [2, 2, 2], [2, 2, 2]]])

0 commit comments

Comments
 (0)