Skip to content

Commit afd8069

Browse files
committed
WIP Tree validator
Has support for lists so far
1 parent d704372 commit afd8069

File tree

1 file changed

+133
-48
lines changed

1 file changed

+133
-48
lines changed

src/Codec/CBOR/Cuddle/CBOR/Validator.hs

Lines changed: 133 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ import Data.Bits hiding (And)
2020
import Data.ByteString qualified as BS
2121
import Data.ByteString.Lazy qualified as BSL
2222
import Data.Function ((&))
23-
import Data.Functor ((<&>))
23+
import Data.Functor ((<&>), ($>))
2424
import Data.Functor.Identity
2525
import Data.IntSet qualified as IS
2626
import Data.List.NonEmpty qualified as NE
@@ -60,9 +60,9 @@ data CDDLResult
6060
-- | Rule we are trying
6161
Rule
6262
-- | List of expansions of rules
63-
[[Rule]]
63+
ExpansionTree
6464
-- | For each expansion, for each of the rules in the expansion, the result
65-
[[(Rule, CBORTermResult)]]
65+
(ExpansionTree' [(Rule, CBORTermResult)])
6666
| -- | All expansions failed
6767
--
6868
-- An expansion is: Given a CBOR @TMap@ of @N@ elements, we will expand the
@@ -71,7 +71,7 @@ data CDDLResult
7171
-- | Rule we are trying
7272
Rule
7373
-- | List of expansions
74-
[[Rule]]
74+
ExpansionTree
7575
-- | A list of matched items @(key, value, rule)@ and the unmatched item
7676
[([AMatchedItem], ANonMatchedItem)]
7777
| -- | The rule was valid but the control failed
@@ -655,58 +655,132 @@ flattenGroup cddl nodes =
655655
| rule <- nodes
656656
]
657657

658+
-- | A filter on a subtree in an expansion. How this is used will depend on the
659+
-- contenxt in which this expansion is used. For maps, we filter based on the
660+
-- key, which can be in any position. For arrays, we filter based on the first
661+
-- value.
662+
data Filter
663+
= NoFilter
664+
| Filter {mapFilter :: Rule, arrayFilter :: Rule}
665+
deriving Show
666+
667+
-- | A tree of possible expansions of a rule matching the size of a container to
668+
-- validate. This tree contains filters at each node, such that we can
669+
-- short-circuit the branch.
670+
--
671+
-- Note that, for simplicity's sake, the gates do not actually consume tokens,
672+
-- so once we reach a leaf we must match it entire against the input.
673+
--
674+
-- The leaves of an expansion tree may be of different lengths until we merge
675+
-- them.
676+
data ExpansionTree' r
677+
= -- | A leaf represents the full sequence of rules which must be matched
678+
Leaf r
679+
| -- | Multiple possibilities for matching
680+
Branch [ExpansionTree' r]
681+
| -- | Set of possibilities guarded by a filter
682+
FilterBranch Filter (ExpansionTree' r)
683+
deriving (Functor, Show)
684+
685+
-- | Merge trees
686+
--
687+
-- We merge from the left, folding a copy of the second tree into each interior
688+
-- node in the first.
689+
mergeTrees :: [ExpansionTree] -> ExpansionTree
690+
mergeTrees [] = Branch []
691+
mergeTrees (a : as) = foldl' go a as
692+
where
693+
go (Leaf xs) b = prependRules xs b
694+
go (Branch xs) b = Branch $ fmap (flip go b) xs
695+
go (FilterBranch f x) b = FilterBranch f $ go x b
696+
697+
-- | Clamp a tree to contain only expressions with a fixed number of elements.
698+
clampTree :: Int -> ExpansionTree -> ExpansionTree
699+
clampTree sz a = maybe (Branch []) id (go a)
700+
where
701+
go l@(Leaf x) = if length x == sz then Just l else Nothing
702+
go (Branch xs) = case catMaybes (go <$> xs) of
703+
[] -> Nothing
704+
ys -> Just $ Branch ys
705+
go (FilterBranch f x) = FilterBranch f <$> go x
706+
707+
type ExpansionTree = ExpansionTree' [Rule]
708+
709+
prependRule :: Rule -> ExpansionTree -> ExpansionTree
710+
prependRule r t = (r :) <$> t
711+
712+
-- | Prepend the given rules atop each leaf node in the tree
713+
prependRules :: [Rule] -> ExpansionTree -> ExpansionTree
714+
prependRules rs t = (rs <>) <$> t
715+
716+
filterOn :: Rule -> Reader CDDL Filter
717+
filterOn rule =
718+
getRule rule >>= \case
719+
KV k v _ -> pure $ Filter k v
720+
_ -> pure NoFilter
721+
658722
-- | Expand rules to reach exactly the wanted length, which must be the number
659723
-- of items in the container. For example, if we want to validate 3 elements,
660724
-- and we have the following CDDL:
661725
--
662726
-- > a = [* int, * bool]
663727
--
664-
-- this will be expanded to `[int, int, int], [int, int, bool], [int, bool,
665-
-- bool], [bool, bool, bool]`.
728+
-- this will be expanded to
729+
-- ```
730+
-- [int, int, bool]
731+
-- int
732+
-- [int, int, int]
733+
-- int
734+
-- bool
735+
-- [int, bool, bool]
736+
-- *
737+
-- bool
738+
-- [bool, bool, bool]
739+
--
740+
-- ```
666741
--
667742
-- Essentially the rules we will parse is the choice among the expansions of the
668743
-- original rules.
669-
expandRules :: Int -> [Rule] -> Reader CDDL [[Rule]]
744+
expandRules :: Int -> [Rule] -> Reader CDDL ExpansionTree
670745
expandRules remainingLen []
671-
| remainingLen /= 0 = pure []
672-
expandRules _ [] = pure [[]]
746+
| remainingLen /= 0 = pure $ Branch []
747+
expandRules _ [] = pure $ Branch []
673748
expandRules remainingLen _
674-
| remainingLen < 0 = pure []
675-
| remainingLen == 0 = pure [[]]
676-
expandRules remainingLen (x : xs) = do
677-
y <- expandRule remainingLen x
678-
concat
679-
<$> mapM
680-
( \y' -> do
681-
suffixes <- expandRules (remainingLen - length y') xs
682-
pure [y' ++ ys' | ys' <- suffixes]
683-
)
684-
y
749+
| remainingLen < 0 = pure $ Branch []
750+
| remainingLen == 0 = pure $ Branch []
751+
expandRules remainingLen xs = do
752+
ys <- traverse (expandRule remainingLen) xs
753+
pure . clampTree remainingLen $ mergeTrees ys
685754

686-
expandRule :: Int -> Rule -> Reader CDDL [[Rule]]
755+
expandRule :: Int -> Rule -> Reader CDDL ExpansionTree
687756
expandRule maxLen _
688-
| maxLen < 0 = pure []
757+
| maxLen < 0 = pure $ Branch []
689758
expandRule maxLen rule =
690759
getRule rule >>= \case
691-
Occur o OIOptional -> pure $ [] : [[o] | maxLen > 0]
692-
Occur o OIZeroOrMore -> ([] :) <$> expandRule maxLen (MIt (Occur o OIOneOrMore))
760+
-- For an optional branch, there is no point including a separate filter
761+
Occur o OIOptional -> pure $ Branch [Leaf [o] | maxLen > 0]
762+
Occur o OIZeroOrMore -> do
763+
f <- filterOn o
764+
FilterBranch f <$> expandRule maxLen (MIt (Occur o OIOneOrMore))
693765
Occur o OIOneOrMore ->
694766
if maxLen > 0
695-
then ([o] :) . map (o :) <$> expandRule (maxLen - 1) (MIt (Occur o OIOneOrMore))
696-
else pure []
767+
then do
768+
f <- filterOn o
769+
FilterBranch f . prependRule o <$> expandRule (maxLen - 1) (MIt (Occur o OIOneOrMore))
770+
else pure $ Branch []
697771
Occur o (OIBounded low high) -> case (low, high) of
698772
(Nothing, Nothing) -> expandRule maxLen (MIt (Occur o OIZeroOrMore))
699773
(Just (fromIntegral -> low'), Nothing) ->
700774
if maxLen >= low'
701-
then map (replicate low' o ++) <$> expandRule (maxLen - low') (MIt (Occur o OIZeroOrMore))
702-
else pure []
775+
then (prependRules $ replicate low' o) <$> expandRule (maxLen - low') (MIt (Occur o OIZeroOrMore))
776+
else pure $ Branch []
703777
(Nothing, Just (fromIntegral -> high')) ->
704-
pure [replicate n o | n <- [0 .. min maxLen high']]
778+
pure $ Branch [Leaf $ replicate n o | n <- [0 .. min maxLen high']]
705779
(Just (fromIntegral -> low'), Just (fromIntegral -> high')) ->
706780
if maxLen >= low'
707-
then pure [replicate n o | n <- [low' .. min maxLen high']]
708-
else pure []
709-
_ -> pure [[rule | maxLen > 0]]
781+
then pure $ Branch [Leaf $ replicate n o | n <- [low' .. min maxLen high']]
782+
else pure $ Branch []
783+
_ -> pure $ Branch [Leaf [rule] | maxLen > 0]
710784

711785
-- | Which rules are optional?
712786
isOptional :: MonadReader CDDL m => Rule -> m Bool
@@ -725,9 +799,9 @@ isOptional rule =
725799
validateListWithExpandedRules ::
726800
forall m.
727801
MonadReader CDDL m =>
728-
[Term] -> [Rule] -> m [(Rule, CBORTermResult)]
802+
NE.NonEmpty Term -> [Rule] -> m [(Rule, CBORTermResult)]
729803
validateListWithExpandedRules terms rules =
730-
go (zip terms rules)
804+
go (zip (NE.toList terms) rules)
731805
where
732806
go ::
733807
[(Term, Rule)] -> m [(Rule, CBORTermResult)]
@@ -751,26 +825,37 @@ validateListWithExpandedRules terms rules =
751825
validateExpandedList ::
752826
forall m.
753827
MonadReader CDDL m =>
754-
[Term] ->
755-
[[Rule]] ->
828+
NE.NonEmpty Term ->
829+
ExpansionTree ->
756830
m (Rule -> CDDLResult)
757831
validateExpandedList terms rules = go rules
758832
where
759-
go :: [[Rule]] -> m (Rule -> CDDLResult)
760-
go [] = pure $ \r -> ListExpansionFail r rules []
761-
go (choice : choices) = do
833+
go :: ExpansionTree -> m (Rule -> CDDLResult)
834+
go (Leaf choice) = do
762835
res <- validateListWithExpandedRules terms choice
763836
case res of
764837
[] -> pure Valid
765838
_ -> case last res of
766839
(_, CBORTermResult _ (Valid _)) -> pure Valid
767-
_ ->
768-
go choices
769-
>>= ( \case
770-
Valid _ -> pure Valid
771-
ListExpansionFail _ _ errors -> pure $ \r -> ListExpansionFail r rules (res : errors)
772-
)
773-
. ($ dummyRule)
840+
_ -> pure $ \r -> ListExpansionFail r rules (Leaf res)
841+
go (FilterBranch f x) = validateTerm (NE.head terms) (arrayFilter f) >>= \case
842+
(CBORTermResult _ (Valid _)) -> go x
843+
-- In this case we insert a leaf since we haven't actually validated the
844+
-- subnodes.
845+
err -> pure $ \r -> ListExpansionFail r rules $ FilterBranch f $ Leaf [(r, err)]
846+
go (Branch xs) = goBranch xs
847+
848+
goBranch [] = pure $ \r -> ListExpansionFail r rules $ Branch []
849+
goBranch (x:xs) = go x <&> ($ dummyRule) >>= \case
850+
Valid _ -> pure Valid
851+
ListExpansionFail _ _ errors -> prependBranchErrors errors <$> goBranch xs
852+
853+
prependBranchErrors errors res = case res dummyRule of
854+
Valid _ -> Valid
855+
ListExpansionFail _ _ errors2 -> \r ->
856+
ListExpansionFail r rules $ errors <> errors2
857+
858+
774859

775860
validateList ::
776861
MonadReader CDDL m => [Term] -> Rule -> m CDDLResult
@@ -781,11 +866,11 @@ validateList terms rule =
781866
Array rules ->
782867
case terms of
783868
[] -> ifM (and <$> mapM isOptional rules) (pure Valid) (pure InvalidRule)
784-
_ ->
869+
t:ts ->
785870
ask >>= \cddl ->
786871
let sequencesOfRules =
787872
runReader (expandRules (length terms) $ flattenGroup cddl rules) cddl
788-
in validateExpandedList terms sequencesOfRules
873+
in validateExpandedList (t NE.:| ts) sequencesOfRules
789874
Choice opts -> validateChoice (validateList terms) opts
790875
_ -> pure UnapplicableRule
791876

0 commit comments

Comments
 (0)