Skip to content

Commit e8cea72

Browse files
authored
Merge pull request #23 from McMasterU/add_simplify
[ solver ] simplify constraints & obj, add tests
2 parents 73b0f7d + bed9a3b commit e8cea72

File tree

9 files changed

+98
-52
lines changed

9 files changed

+98
-52
lines changed

src/HashedExpression/Codegen/CSimple.hs

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -383,8 +383,8 @@ instance Codegen CSimpleConfig where
383383
params :: [String]
384384
params = map fst $ paramNodesWithId expressionMap
385385
-- value nodes
386-
vs :: [(String, Int)]
387-
vs = sortOn fst $ varNodesWithId expressionMap ++ paramNodesWithId expressionMap
386+
varsAndParams :: [(String, Int)]
387+
varsAndParams = sortOn fst $ varNodesWithId expressionMap ++ paramNodesWithId expressionMap
388388
-- get shape of a variable
389389
variableShape :: String -> Shape
390390
variableShape name =
@@ -405,7 +405,7 @@ instance Codegen CSimpleConfig where
405405
let isOk (var, nId)
406406
| Just val <- Map.lookup var valMaps = compatible (retrieveShape nId expressionMap) val
407407
| otherwise = True,
408-
Just (var, shape) <- find (not . isOk) vs =
408+
Just (var, shape) <- find (not . isOk) varsAndParams =
409409
Just $ "variable " ++ var ++ "is of shape " ++ show shape ++ " but the value provided is not"
410410
| otherwise = Nothing
411411
-------------------------------------------------------------------------------
@@ -428,8 +428,8 @@ instance Codegen CSimpleConfig where
428428
offset = cAddress nId
429429
shape = retrieveShape nId expressionMap
430430
-------------------------------------------------------------------------------
431-
writeResultCodeEach :: (String, NodeID) -> Code
432-
writeResultCodeEach (name, nId)
431+
writeResultCodeEach :: Variable -> Code
432+
writeResultCodeEach variable
433433
| output == OutputHDF5 =
434434
scoped
435435
[ [i|printf("Writing #{name} to #{name}_out.h5...\\n");|],
@@ -455,6 +455,8 @@ instance Codegen CSimpleConfig where
455455
)
456456
++ ["fclose(file);"]
457457
where
458+
nId = nodeId variable
459+
name = varName variable
458460
offset = cAddress nId
459461
shape = retrieveShape nId expressionMap
460462
-------------------------------------------------------------------------------
@@ -555,12 +557,12 @@ instance Codegen CSimpleConfig where
555557
readValsCodes =
556558
["void read_values() {"]
557559
++ [" srand(time(NULL));"] --
558-
++ scoped (concatMap readValCodeEach vs)
560+
++ scoped (concatMap readValCodeEach varsAndParams)
559561
++ ["}"] --
560562
-------------------------------------------------------------------------------
561563
writeResultCodes =
562564
["void write_result()"]
563-
++ scoped (concatMap writeResultCodeEach vs)
565+
++ scoped (concatMap writeResultCodeEach variables)
564566
-------------------------------------------------------------------------------
565567
evaluatingCodes =
566568
["void evaluate_partial_derivatives_and_objective()"]

src/HashedExpression/Internal/Simplify.hs

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
-- Portability : unportable
88
--
99
-- Simplifying expressions
10-
module HashedExpression.Internal.Simplify (simplify) where
10+
module HashedExpression.Internal.Simplify (simplify, simplifyUnwrapped) where
1111

1212
import Control.Monad.State.Strict
1313
import Data.Eq.HT (equating)
@@ -42,8 +42,8 @@ import HashedExpression.Prettify
4242
import Prelude hiding ((^))
4343
import qualified Prelude
4444

45-
simplify :: forall d et. (Dimension d, ElementType et) => Expression d et -> Expression d et
46-
simplify = wrap . removeUnreachable . apply . unwrap
45+
simplifyUnwrapped :: (ExpressionMap, NodeID) -> (ExpressionMap, NodeID)
46+
simplifyUnwrapped = removeUnreachable . apply
4747
where
4848
apply =
4949
multipleTimes 1000 . toRecursiveTransformation . chainModifications $
@@ -73,6 +73,9 @@ simplify = wrap . removeUnreachable . apply . unwrap
7373
]
7474
)
7575

76+
simplify :: forall d et. (Dimension d, ElementType et) => Expression d et -> Expression d et
77+
simplify = wrap . simplifyUnwrapped . unwrap
78+
7679
-- | Predefined holes used for pattern matching with 'Pattern'
7780
[p, q, r, s, t, u, v, w, x, y, z, condition] = map PHole [1 .. 12]
7881

@@ -199,17 +202,15 @@ zeroOneSumProdRules :: Modification
199202
zeroOneSumProdRules exp@(mp, n) =
200203
case retrieveOp n mp of
201204
Sum ns
202-
-- to make sure filter (not . isZero mp) ns is not empty
203-
| all (isZero mp) ns -> just $ head ns
204205
-- if the sumP has any zero, remove them
205206
-- sum(x, y, z, 0, t, 0) = sum(x, y, z, t)
206-
| any (isZero mp) ns -> sum_ . map just . filter (not . isZero mp) $ ns
207+
| (x : _, []) <- partition (isZero mp) ns -> just x
208+
| (_, nonZeros) <- partition (isZero mp) ns -> sum_ . map just $ nonZeros
207209
Mul ns
208-
-- to make sure filter (not . isOne mp) ns is not empty
209-
| all (isOne mp) ns -> just $ head ns
210210
-- if the product has any one, remove them
211211
-- product(x, y, z, 1, t, 1) = product(x, y, z, t)
212-
| any (isOne mp) ns -> product_ . map just . filter (not . isOne mp) $ ns
212+
| (x : _, []) <- partition (isOne mp) ns -> just x
213+
| (_, nonOnes) <- partition (isOne mp) ns -> product_ . map just $ nonOnes
213214
-- if any is zero, collapse to zero
214215
-- product(x, y, z, 0, t, u, v) = 0
215216
| nId : _ <- filter (isZero mp) ns -> just nId

src/HashedExpression/Internal/Utils.hs

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
module HashedExpression.Internal.Utils where
22

33
import Data.Array
4-
import Data.Complex
4+
import qualified Data.Complex as Complex
5+
import Data.Complex (Complex(..))
56
import qualified Data.IntMap.Strict as IM
67
import Data.List (foldl')
78
import Data.List.Split (splitOn)
@@ -183,3 +184,10 @@ zipMp3 mp1 mp2 mp3 = foldl' f Map.empty $ Map.keys mp1
183184

184185
instance PowerOp Double Int where
185186
(^) x y = x Prelude.^ y
187+
188+
instance ComplexRealOp Double (Complex Double) where
189+
(+:) = (:+)
190+
xRe = Complex.realPart
191+
xIm = Complex.imagPart
192+
conjugate = Complex.conjugate
193+

src/HashedExpression/Interp.hs

Lines changed: 24 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,10 @@ module HashedExpression.Interp
1616
evaluate2DComplex,
1717
evaluate3DReal,
1818
evaluate3DComplex,
19+
fourierTransform1D,
20+
fourierTransform2D,
21+
fourierTransform3D,
22+
FTMode(..),
1923
)
2024
where
2125

@@ -463,8 +467,8 @@ evaluate1DComplex valMap (mp, n)
463467
]
464468
Rotate [amount] arg ->
465469
rotate1D size amount (evaluate1DComplex valMap $ (mp, arg))
466-
FT arg -> fourierTransform1D False size $ evaluate1DComplex valMap (mp, arg)
467-
IFT arg -> fourierTransform1D True size $ evaluate1DComplex valMap (mp, arg)
470+
FT arg -> fourierTransform1D FT_FORWARD size $ evaluate1DComplex valMap (mp, arg)
471+
IFT arg -> fourierTransform1D FT_BACKWARD size $ evaluate1DComplex valMap (mp, arg)
468472
_ -> error "expression structure One C is wrong"
469473
| otherwise = error "one C but shape is not [size] ??"
470474

@@ -600,8 +604,8 @@ evaluate2DComplex valMap (mp, n)
600604
(size1, size2)
601605
(amount1, amount2)
602606
(evaluate2DComplex valMap $ (mp, arg))
603-
FT arg -> fourierTransform2D False (size1, size2) $ evaluate2DComplex valMap (mp, arg)
604-
IFT arg -> fourierTransform2D True (size1, size2) $ evaluate2DComplex valMap (mp, arg)
607+
FT arg -> fourierTransform2D FT_FORWARD (size1, size2) $ evaluate2DComplex valMap (mp, arg)
608+
IFT arg -> fourierTransform2D FT_BACKWARD (size1, size2) $ evaluate2DComplex valMap (mp, arg)
605609
_ -> error "expression structure Two C is wrong"
606610
| otherwise = error "Two C but shape is not [size1, size2] ??"
607611

@@ -739,8 +743,8 @@ evaluate3DComplex valMap (mp, n)
739743
(size1, size2, size3)
740744
(amount1, amount2, amount3)
741745
(evaluate3DComplex valMap $ (mp, arg))
742-
FT arg -> fourierTransform3D False (size1, size2, size3) $ evaluate3DComplex valMap (mp, arg)
743-
IFT arg -> fourierTransform3D True (size1, size2, size3) $ evaluate3DComplex valMap (mp, arg)
746+
FT arg -> fourierTransform3D FT_FORWARD (size1, size2, size3) $ evaluate3DComplex valMap (mp, arg)
747+
IFT arg -> fourierTransform3D FT_BACKWARD (size1, size2, size3) $ evaluate3DComplex valMap (mp, arg)
744748
_ -> error "expression structure Three C is wrong"
745749
| otherwise = error "Three C but shape is not [size1, size2, size3] ??"
746750

@@ -816,6 +820,8 @@ rotate3D (size1, size2, size3) (amount1, amount2, amount3) arr =
816820
k <- [0 .. size3 - 1]
817821
]
818822

823+
data FTMode = FT_FORWARD | FT_BACKWARD deriving (Eq, Ord)
824+
819825
-- | Fourier Transform in 1D.
820826
-- Frequency is just in one dimension.
821827
-- Consider a real-valued function, S(x),
@@ -824,14 +830,14 @@ rotate3D (size1, size2, size3) (amount1, amount2, amount3) arr =
824830
-- length of cycle is P/n, and frequency is n/P.
825831
-- so for input i the frequency is (2*pi*i*n)/P
826832
fourierTransform1D ::
827-
Bool -> Int -> Array Int (Complex Double) -> Array Int (Complex Double)
828-
fourierTransform1D inverse size arr =
833+
FTMode -> Int -> Array Int (Complex Double) -> Array Int (Complex Double)
834+
fourierTransform1D mode size arr =
829835
listArray (0, size - 1) [computeX i | i <- [0 .. size - 1]]
830836
where
831-
s = if inverse then fromIntegral size else 1
837+
s = if mode == FT_BACKWARD then fromIntegral size else 1
832838
computeX i = (sum $ zipWithA (*) arr (fourierBasis i)) / s
833839
fourierBasis i =
834-
let frequency n = (2 * pi * fromIntegral (i * n) / fromIntegral size) * (if inverse then -1 else 1)
840+
let frequency n = (2 * pi * fromIntegral (i * n) / fromIntegral size) * (if mode == FT_BACKWARD then -1 else 1)
835841
in listArray
836842
(0, size - 1)
837843
[ cos (frequency n) :+ (- sin (frequency n))
@@ -847,23 +853,23 @@ fourierTransform1D inverse size arr =
847853
-- so for input i the frequency is (2*pi*i*n)/P
848854
-- the frequency should be calculated in both dimensions for i and j
849855
fourierTransform2D ::
850-
Bool ->
856+
FTMode ->
851857
(Int, Int) ->
852858
Array (Int, Int) (Complex Double) ->
853859
Array (Int, Int) (Complex Double)
854-
fourierTransform2D inverse (size1, size2) arr =
860+
fourierTransform2D mode (size1, size2) arr =
855861
listArray
856862
((0, 0), (size1 - 1, size2 - 1))
857863
[computeX i j | i <- [0 .. size1 - 1], j <- [0 .. size2 - 1]]
858864
where
859-
s = if inverse then fromIntegral (size1 * size2) else 1
865+
s = if mode == FT_BACKWARD then fromIntegral (size1 * size2) else 1
860866
computeX i j = (sum $ zipWithA (*) arr (fourierBasis i j)) / s
861867
fourierBasis i j =
862868
let frequency m n =
863869
( 2 * pi * fromIntegral (i * m) / fromIntegral size1
864870
+ 2 * pi * fromIntegral (j * n) / fromIntegral size2
865871
)
866-
* (if inverse then -1 else 1)
872+
* (if mode == FT_BACKWARD then -1 else 1)
867873
in listArray
868874
((0, 0), (size1 - 1, size2 - 1))
869875
[ cos (frequency m n) :+ (- sin (frequency m n))
@@ -880,11 +886,11 @@ fourierTransform2D inverse (size1, size2) arr =
880886
-- so for input i the frequency is (2*pi*i*n)/P
881887
-- the frequency should be calculated for all dimensions, i , j , k
882888
fourierTransform3D ::
883-
Bool ->
889+
FTMode ->
884890
(Int, Int, Int) ->
885891
Array (Int, Int, Int) (Complex Double) ->
886892
Array (Int, Int, Int) (Complex Double)
887-
fourierTransform3D inverse (size1, size2, size3) arr =
893+
fourierTransform3D mode (size1, size2, size3) arr =
888894
listArray
889895
((0, 0, 0), (size1 - 1, size2 - 1, size3 - 1))
890896
[ computeX i j k
@@ -893,15 +899,15 @@ fourierTransform3D inverse (size1, size2, size3) arr =
893899
k <- [0 .. size3 - 1]
894900
]
895901
where
896-
s = if inverse then fromIntegral (size1 * size2) else 1
902+
s = if mode == FT_BACKWARD then fromIntegral (size1 * size2) else 1
897903
computeX i j k = (sum $ zipWithA (*) arr (fourierBasis i j k)) / s
898904
fourierBasis i j k =
899905
let frequency m n p =
900906
( 2 * pi * fromIntegral (i * m) / fromIntegral size1
901907
+ 2 * pi * fromIntegral (j * n) / fromIntegral size2
902908
+ 2 * pi * fromIntegral (k * p) / fromIntegral size3
903909
)
904-
* (if inverse then -1 else 1)
910+
* (if mode == FT_BACKWARD then -1 else 1)
905911
in listArray
906912
((0, 0, 0), (size1 - 1, size2 - 1, size3 - 1))
907913
[ cos (frequency m n p) :+ (- sin (frequency m n p))

src/HashedExpression/Problem.hs

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ import HashedExpression.Internal.Expression
3030
import HashedExpression.Internal.Node
3131
import HashedExpression.Internal.OperationSpec
3232
import HashedExpression.Internal.Rewrite
33+
import HashedExpression.Internal.Simplify
3334
import HashedExpression.Internal.Utils
3435
import HashedExpression.Prettify (debugPrint)
3536
import HashedExpression.Value
@@ -195,6 +196,13 @@ getExpressionCS cs =
195196
Upper exp _ -> exp
196197
Between exp _ -> exp
197198

199+
mapExpressionCS :: ((ExpressionMap, NodeID) -> (ExpressionMap, NodeID)) -> ConstraintStatement -> ConstraintStatement
200+
mapExpressionCS f cs =
201+
case cs of
202+
Lower exp v -> Lower (f exp) v
203+
Upper exp v -> Upper (f exp) v
204+
Between exp v -> Between (f exp) v
205+
198206
-- | Extract the value from the 'ConstraintStatement'
199207
getValCS :: ConstraintStatement -> [Val]
200208
getValCS cs =
@@ -264,13 +272,6 @@ mergeToMain (mp, nID) = do
264272
put mergedMp
265273
return mergedNID
266274

267-
mergeToMainMany :: (ExpressionMap, [NodeID]) -> ProblemConstructingM [NodeID]
268-
mergeToMainMany (mp, nIDs) = do
269-
curMp <- get
270-
let (mergedMp, resIDs) = safeMergeManyRoots curMp (mp, nIDs)
271-
put mergedMp
272-
return resIDs
273-
274275
varsWithShape :: (ExpressionMap, NodeID) -> [(String, Shape)]
275276
varsWithShape = mapMaybe collect . IM.toList . fst
276277
where
@@ -298,7 +299,8 @@ constructProblemHelper obj (Constraint constraints) = do
298299
let processF exp = do
299300
let (mp, name2ID) = partialDerivativesMap exp
300301
let (names, beforeMergeIDs) = unzip $ Map.toList name2ID
301-
Map.fromList . zip names <$> mergeToMainMany (mp, beforeMergeIDs)
302+
afterMergedIDs <- mapM (mergeToMain . simplifyUnwrapped . (mp, )) beforeMergeIDs
303+
return $ Map.fromList $ zip names afterMergedIDs
302304
let lookupDerivative :: (String, Shape) -> Map String NodeID -> ProblemConstructingM NodeID
303305
lookupDerivative (name, shape) dMap = case Map.lookup name dMap of
304306
Just dID -> return dID
@@ -377,7 +379,10 @@ constructProblemHelper obj (Constraint constraints) = do
377379

378380
-- | Construct a Problem from given objective function and constraints
379381
constructProblem :: Expression Scalar R -> Constraint -> ProblemResult
380-
constructProblem objectiveFunction constraint =
381-
case runStateT (constructProblemHelper objectiveFunction constraint) IM.empty of
382+
constructProblem objectiveFunction (Constraint cs) =
383+
case runStateT (constructProblemHelper simplifiedObjective simplifiedConstraint) IM.empty of
382384
Left reason -> ProblemInvalid reason
383385
Right (problem, _) -> ProblemValid problem
386+
where
387+
simplifiedObjective = simplify objectiveFunction
388+
simplifiedConstraint = Constraint $ map (mapExpressionCS simplifyUnwrapped) cs

symphony/Symphony/Symphony.hs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,8 @@ generateCode outputPath (ValidSymphony objectiveExp vars consts css solver) = do
131131
case generateProblemCode CSimple.CSimpleConfig {output = CSimple.OutputHDF5} heProblem valMap of
132132
Invalid reason -> throwError $ GeneralError reason
133133
Success res -> do
134+
liftIO $ putStrLn "Problem detail:"
135+
liftIO $ print $ heProblem
134136
liftIO $ res outputPath
135137
liftIO $ putStrLn "Download solver & adapter......"
136138
liftIO $ downloadSolver outputPath solver

test/Commons.hs

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ module Commons where
88
import Control.Applicative (liftA2)
99
import Control.Monad (foldM, forM)
1010
import Data.Array
11-
import Data.Complex
11+
import Data.Complex (Complex(..))
1212
import Data.Function.HT (nest)
1313
import qualified Data.IntMap.Strict as IM
1414
import Data.List (intercalate, sort)
@@ -547,3 +547,9 @@ sz = IM.size . exMap
547547
instance (Ix i, Num a) => Num (Array i a) where
548548
(+) arr1 arr2 = listArray (bounds arr1) $ zipWith (+) (elems arr1) (elems arr2)
549549
(*) arr1 arr2 = listArray (bounds arr1) $ zipWith (*) (elems arr1) (elems arr2)
550+
551+
instance (Ix i) => ComplexRealOp (Array i Double) (Array i (Complex Double)) where
552+
(+:) arr1 arr2 = listArray (bounds arr1) $ zipWith (+:) (elems arr1) (elems arr2)
553+
xRe arr = listArray (bounds arr) $ map xRe (elems arr)
554+
xIm arr = listArray (bounds arr) $ map xIm (elems arr)
555+
conjugate arr = listArray (bounds arr) $ map conjugate (elems arr)

test/SolverSpec.hs

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -48,10 +48,10 @@ runCommandIn cwd cmd = do
4848
(_, _, _, ph) <-
4949
createProcess
5050
(shell cmd)
51-
{ cwd = Just cwd
52-
-- std_in = NoStream,
53-
-- std_out = NoStream,
54-
-- std_err = NoStream
51+
{ cwd = Just cwd,
52+
std_in = NoStream,
53+
std_out = NoStream,
54+
std_err = NoStream
5555
}
5656
waitForProcess ph
5757

@@ -127,5 +127,21 @@ spec =
127127
let xGot = getValue2D "x" (5, 5) res
128128
let xExpect = listArray ((0, 0), (4, 4)) $ replicate 25 (1 / 2.7182818285)
129129
xGot `shouldApprox` xExpect
130+
specify "Fourier transform" $ do
131+
let x = variable1D @10 "x"
132+
let y = variable1D @10 "y"
133+
let a = param1D @10 "a"
134+
let b = param1D @10 "b"
135+
let obj = norm2square (ft (x +: y) - (a +: b))
136+
case constructProblem obj (Constraint []) of
137+
ProblemValid p -> do
138+
valA <- listArray (0, 9) <$> generate (vectorOf 10 arbitrary)
139+
valB <- listArray (0, 9) <$> generate (vectorOf 10 arbitrary)
140+
res <- solveProblem p (Map.fromList [("a", V1D valA), ("b", V1D valB)])
141+
let xGot = getValue1D "x" 10 res
142+
let yGot = getValue1D "y" 10 res
143+
let xExpect = xRe (fourierTransform1D FT_BACKWARD 10 (valA +: valB))
144+
let yExpect = xIm (fourierTransform1D FT_BACKWARD 10 (valA +: valB))
145+
xGot `shouldApprox` xExpect
130146
specify "Banana function" $ do
131147
property prop_Rosenbrock

test/Spec.hs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ main = do
3535
describe "HashedInterpSpec" InterpSpec.spec
3636
describe "StructureSpec" StructureSpec.spec
3737
describe "ReverseDifferentiationSpec" ReverseDifferentiationSpec.spec
38-
-- hspecWith defaultConfig {configQuickCheckMaxSuccess = Just 10} $ do
39-
-- describe "SolverSpec" SolverSpec.spec
38+
hspecWith defaultConfig {configQuickCheckMaxSuccess = Just 10} $ do
39+
describe "SolverSpec" SolverSpec.spec
4040
hspecWith defaultConfig {configQuickCheckMaxSuccess = Just 20} $ do
4141
describe "CSimpleSpec" CSimpleSpec.spec

0 commit comments

Comments
 (0)