Skip to content

Commit e44fd92

Browse files
committed
added map,conv and permute
map - maps points based on equations conv - allows running convolution filter permute - switches dimension position
1 parent ba752c4 commit e44fd92

File tree

11 files changed

+2334
-1005
lines changed

11 files changed

+2334
-1005
lines changed

more_math/Parser/FloatEvalVisitor.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -104,13 +104,18 @@ def visitTanhFunc(self, ctx): return math.tanh(self.visit(ctx.expr()))
104104
def visitAsinhFunc(self, ctx): return math.asinh(self.visit(ctx.expr()))
105105
def visitAcoshFunc(self, ctx): return math.acosh(self.visit(ctx.expr()))
106106
def visitAtanhFunc(self, ctx): return math.atanh(self.visit(ctx.expr()))
107-
def visitAbsFunc(self, ctx): return math.abs(self.visit(ctx.expr()))
108-
def visitAbsExp(self, ctx): return math.abs(self.visit(ctx.expr()))
107+
def visitAbsFunc(self, ctx): return abs(self.visit(ctx.expr()))
108+
def visitAbsExp(self, ctx): return abs(self.visit(ctx.expr()))
109+
def visitListExp(self, ctx): return self.visit(ctx.expr(0))
109110
def visitSqrtFunc(self, ctx): return math.sqrt(self.visit(ctx.expr()))
110111
def visitLnFunc(self, ctx): return math.log(self.visit(ctx.expr()))
111112
def visitLogFunc(self, ctx): return math.log10(self.visit(ctx.expr()))
112113
def visitExpFunc(self, ctx): return math.exp(self.visit(ctx.expr()))
113-
def visitNormFunc(self, ctx): return math.sqrt(math.avg(x**2 for x in self.visit(ctx.expr())))
114+
def visitNormFunc(self, ctx):
115+
vals = self.visit(ctx.expr())
116+
if isinstance(vals, (list, tuple)):
117+
return math.sqrt(sum(x**2 for x in vals) / len(vals))
118+
return abs(vals)
114119
def visitFloorFunc(self, ctx): return math.floor(self.visit(ctx.expr()))
115120
def visitFractFunc(self, ctx):
116121
val = self.visit(ctx.expr())
@@ -131,7 +136,7 @@ def visitSignFunc(self, ctx):
131136
x = self.visit(ctx.expr())
132137
return math.copysign(1.0, x) if x != 0 else 0.0
133138
def visitCeilFunc(self, ctx): return math.ceil(self.visit(ctx.expr()))
134-
def visitRoundFunc(self, ctx): return math.round(self.visit(ctx.expr()))
139+
def visitRoundFunc(self, ctx): return round(self.visit(ctx.expr()))
135140
def visitGammaFunc(self, ctx): return math.gamma(self.visit(ctx.expr())).exp()
136141
def visitPrintFunc(self, ctx):
137142
val = self.visit(ctx.expr())
@@ -151,10 +156,10 @@ def visitStepFunc(self, ctx):
151156
# N-argument functions
152157
def visitSMinFunc(self, ctx):
153158
args = [self.visit(e) for e in ctx.expr()]
154-
return math.min(args)
159+
return min(args)
155160
def visitSMaxFunc(self, ctx):
156161
args = [self.visit(e) for e in ctx.expr()]
157-
return math.max(args)
162+
return max(args)
158163

159164
def visitClampFunc(self, ctx):
160165
x = self.visit(ctx.expr(0))
@@ -188,6 +193,13 @@ def visitFunc3Exp(self, ctx):
188193
return self.visitChildren(ctx)
189194
def visitFunc4Exp(self, ctx):
190195
return self.visitChildren(ctx)
196+
197+
def visitMapFunc(self, ctx):
198+
return self.visit(ctx.expr(0))
199+
200+
def visitConvFunc(self, ctx):
201+
return self.visit(ctx.expr(0))
202+
191203
def visitFuncNExp(self, ctx):
192204
return self.visitChildren(ctx)
193205
def visitAtomExp(self, ctx):

more_math/Parser/MathExpr.g4

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ atom
5151
| CONSTANT # ConstantExp
5252
| '(' expr ')' # ParenExp
5353
| PIPE expr PIPE # AbsExp
54+
| '[' expr (',' expr)* ']' # ListExp
5455
;
5556

5657
// Single-argument functions
@@ -88,7 +89,6 @@ func1
8889
| SOFTPLUS '(' expr ')' # SoftplusFunc
8990
| GELU '(' expr ')' # GeluFunc
9091
| SIGN '(' expr ')' # SignFunc
91-
9292
;
9393

9494
// Two-argument functions
@@ -99,6 +99,7 @@ func2
9999
| TMAX '(' expr ',' expr ')' # TMaxFunc
100100
| STEP '(' expr ',' expr ')' # StepFunc
101101
;
102+
102103
func3
103104
: CLAMP '(' expr ',' expr ',' expr ')' # ClampFunc
104105
| LERP '(' expr ',' expr ',' expr ')' # LerpFunc
@@ -108,11 +109,14 @@ func3
108109
func4
109110
: SWAP '(' expr ',' expr ',' expr ',' expr ')' # SwapFunc
110111
;
111-
// N-argument functions (at least 2 arguments)
112-
funcN
113-
: SMIN '(' expr (',' expr)+ ')' # SMinFunc
114-
| SMAX '(' expr (',' expr)+ ')' # SMaxFunc
115112

113+
// N-argument functions
114+
funcN
115+
: SMIN '(' expr (',' expr)* ')' # SMinFunc
116+
| SMAX '(' expr (',' expr)* ')' # SMaxFunc
117+
| MAP '(' expr (',' expr)+ ')' # MapFunc
118+
| CONV '(' expr (',' expr)+ ')' # ConvFunc
119+
| PERM '(' expr ',' expr ')' # PermuteFunc
116120
;
117121

118122
// LEXER RULES
@@ -160,7 +164,10 @@ RELU : 'relu';
160164
SOFTPLUS : 'softplus';
161165
GELU : 'gelu';
162166
SIGN : 'sign';
167+
MAP : 'map';
168+
CONV : 'conv';
163169
SWAP : 'swap';
170+
PERM : 'permute';
164171

165172
PLUS : '+';
166173
MINUS : '-';

more_math/Parser/MathExpr.interp

Lines changed: 11 additions & 1 deletion
Large diffs are not rendered by default.

more_math/Parser/MathExpr.tokens

Lines changed: 129 additions & 119 deletions
Original file line numberDiff line numberDiff line change
@@ -1,124 +1,134 @@
11
T__0=1
22
T__1=2
33
T__2=3
4-
SIN=4
5-
COS=5
6-
TAN=6
7-
ASIN=7
8-
ACOS=8
9-
ATAN=9
10-
ATAN2=10
11-
SINH=11
12-
COSH=12
13-
TANH=13
14-
ASINH=14
15-
ACOSH=15
16-
ATANH=16
17-
ABS=17
18-
SQRT=18
19-
LN=19
20-
LOG=20
21-
EXP=21
22-
SMIN=22
23-
SMAX=23
24-
TMIN=24
25-
TMAX=25
26-
TNORM=26
27-
SNORM=27
28-
FLOOR=28
29-
CEIL=29
30-
ROUND=30
31-
GAMMA=31
32-
POWE=32
33-
SIGM=33
34-
CLAMP=34
35-
SFFT=35
36-
SIFFT=36
37-
ANGL=37
38-
PRNT=38
39-
LERP=39
40-
STEP=40
41-
SMOOTHSTEP=41
42-
FRACT=42
43-
RELU=43
44-
SOFTPLUS=44
45-
GELU=45
46-
SIGN=46
47-
SWAP=47
48-
PLUS=48
49-
MINUS=49
50-
MULT=50
51-
DIV=51
52-
MOD=52
53-
POW=53
54-
GE=54
55-
GT=55
56-
LE=56
57-
LT=57
58-
EQ=58
59-
NE=59
60-
PIPE=60
61-
CONSTANT=61
62-
NUMBER=62
63-
VARIABLE=63
64-
WS=64
4+
T__3=4
5+
T__4=5
6+
SIN=6
7+
COS=7
8+
TAN=8
9+
ASIN=9
10+
ACOS=10
11+
ATAN=11
12+
ATAN2=12
13+
SINH=13
14+
COSH=14
15+
TANH=15
16+
ASINH=16
17+
ACOSH=17
18+
ATANH=18
19+
ABS=19
20+
SQRT=20
21+
LN=21
22+
LOG=22
23+
EXP=23
24+
SMIN=24
25+
SMAX=25
26+
TMIN=26
27+
TMAX=27
28+
TNORM=28
29+
SNORM=29
30+
FLOOR=30
31+
CEIL=31
32+
ROUND=32
33+
GAMMA=33
34+
POWE=34
35+
SIGM=35
36+
CLAMP=36
37+
SFFT=37
38+
SIFFT=38
39+
ANGL=39
40+
PRNT=40
41+
LERP=41
42+
STEP=42
43+
SMOOTHSTEP=43
44+
FRACT=44
45+
RELU=45
46+
SOFTPLUS=46
47+
GELU=47
48+
SIGN=48
49+
MAP=49
50+
CONV=50
51+
SWAP=51
52+
PERM=52
53+
PLUS=53
54+
MINUS=54
55+
MULT=55
56+
DIV=56
57+
MOD=57
58+
POW=58
59+
GE=59
60+
GT=60
61+
LE=61
62+
LT=62
63+
EQ=63
64+
NE=64
65+
PIPE=65
66+
CONSTANT=66
67+
NUMBER=67
68+
VARIABLE=68
69+
WS=69
6570
'('=1
6671
')'=2
67-
','=3
68-
'sin'=4
69-
'cos'=5
70-
'tan'=6
71-
'asin'=7
72-
'acos'=8
73-
'atan'=9
74-
'atan2'=10
75-
'sinh'=11
76-
'cosh'=12
77-
'tanh'=13
78-
'asinh'=14
79-
'acosh'=15
80-
'atanh'=16
81-
'abs'=17
82-
'sqrt'=18
83-
'ln'=19
84-
'log'=20
85-
'exp'=21
86-
'smin'=22
87-
'smax'=23
88-
'tmin'=24
89-
'tmax'=25
90-
'tnorm'=26
91-
'snorm'=27
92-
'floor'=28
93-
'ceil'=29
94-
'round'=30
95-
'gamma'=31
96-
'pow'=32
97-
'sigm'=33
98-
'clamp'=34
99-
'fft'=35
100-
'ifft'=36
101-
'angle'=37
102-
'print'=38
103-
'lerp'=39
104-
'step'=40
105-
'smoothstep'=41
106-
'fract'=42
107-
'relu'=43
108-
'softplus'=44
109-
'gelu'=45
110-
'sign'=46
111-
'swap'=47
112-
'+'=48
113-
'-'=49
114-
'*'=50
115-
'/'=51
116-
'%'=52
117-
'^'=53
118-
'>='=54
119-
'>'=55
120-
'<='=56
121-
'<'=57
122-
'=='=58
123-
'!='=59
124-
'|'=60
72+
'['=3
73+
','=4
74+
']'=5
75+
'sin'=6
76+
'cos'=7
77+
'tan'=8
78+
'asin'=9
79+
'acos'=10
80+
'atan'=11
81+
'atan2'=12
82+
'sinh'=13
83+
'cosh'=14
84+
'tanh'=15
85+
'asinh'=16
86+
'acosh'=17
87+
'atanh'=18
88+
'abs'=19
89+
'sqrt'=20
90+
'ln'=21
91+
'log'=22
92+
'exp'=23
93+
'smin'=24
94+
'smax'=25
95+
'tmin'=26
96+
'tmax'=27
97+
'tnorm'=28
98+
'snorm'=29
99+
'floor'=30
100+
'ceil'=31
101+
'round'=32
102+
'gamma'=33
103+
'pow'=34
104+
'sigm'=35
105+
'clamp'=36
106+
'fft'=37
107+
'ifft'=38
108+
'angle'=39
109+
'print'=40
110+
'lerp'=41
111+
'step'=42
112+
'smoothstep'=43
113+
'fract'=44
114+
'relu'=45
115+
'softplus'=46
116+
'gelu'=47
117+
'sign'=48
118+
'map'=49
119+
'conv'=50
120+
'swap'=51
121+
'permute'=52
122+
'+'=53
123+
'-'=54
124+
'*'=55
125+
'/'=56
126+
'%'=57
127+
'^'=58
128+
'>='=59
129+
'>'=60
130+
'<='=61
131+
'<'=62
132+
'=='=63
133+
'!='=64
134+
'|'=65

0 commit comments

Comments
 (0)