Skip to content

Commit c941403

Browse files
authored
Merge pull request #1 from arhik/main
`decl` expressions and `type` expressions
2 parents 63fb6ee + 4a078be commit c941403

File tree

7 files changed

+61
-10
lines changed

7 files changed

+61
-10
lines changed

src/WGPUTranspiler.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@ include("codegen/rangeBlock.jl")
1212
include("codegen/conditionBlock.jl")
1313
include("codegen/funcBlock.jl")
1414
include("codegen/builtin.jl")
15-
include("codegen/computeBlock.jl")
1615
include("codegen/expr.jl")
16+
include("codegen/computeBlock.jl")
1717
include("codegen/infer.jl")
1818
include("codegen/resolve.jl")
1919
include("codegen/transpile.jl")

src/codegen/assignment.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ end
4040
symbol(assign::AssignmentExpr) = symbol(assign.lhs)
4141

4242
function assignExpr(scope, lhs, rhs)
43-
lhsVar = LHS(inferVariable(scope, lhs), false)
43+
lhsVar = LHS(inferExpr(scope, lhs), false)
4444
inferScope!(scope, lhsVar)
4545
rhsExpr = RHS(inferExpr(scope, rhs))
4646
statement = AssignmentExpr(lhsVar, rhsExpr, scope)

src/codegen/computeBlock.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ end
2323

2424
struct ComputeBlock <: JLBlock
2525
fname::WGPUVariable
26-
fargs::Vector{WGPUVariable}
26+
fargs::Vector{DeclExpr}
2727
Targs::Vector{WGPUVariable}
2828
fbody::Vector{JLExpr}
2929
scope::Union{Nothing, Scope}
@@ -35,7 +35,7 @@ function computeBlock(scope, islaunch, wgSize, wgCount, fname, fargs)
3535
@capture(fexpr, function fname_(fargs__) where Targs__ fbody__ end)
3636
childScope = Scope([Targs...], [:ceil], 0, scope, quote end)
3737
fn = inferExpr(childScope, fname)
38-
fa = map(x -> inferVariable(childScope, x), fargs)
38+
fa = map(x -> inferExpr(childScope, x), fargs)
3939
fb = map(x -> inferExpr(childScope, x), fbody)
4040
return ComputeBlock(fn, fa, WGPUVariable[], fb, childScope)
4141
end

src/codegen/expr.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,5 +102,22 @@ end
102102
symbol(access::AccessExpr) = symbol(access.sym)
103103

104104
function inferScope!(scope::Scope, jlexpr::AccessExpr)
105+
#inferScope!(scope, jlexpr.sym)
106+
#inferScope!(scope, jlexpr.field)
107+
end
108+
109+
struct TypeExpr <: JLExpr
110+
sym::WGPUVariable
111+
types::Vector{WGPUVariable}
112+
end
113+
114+
symbol(tExpr::TypeExpr) = (symbol(tExpr.sym), map(x -> symbol(x), tExpr.types)...)
105115

116+
struct DeclExpr <: JLExpr
117+
sym::WGPUVariable
118+
dataType::Union{DataType, TypeExpr}
106119
end
120+
121+
symbol(decl::DeclExpr) = symbol(decl.sym)
122+
123+

src/codegen/infer.jl

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,14 @@ function inferExpr(scope::Scope, expr::Expr)
2929
return binaryOp(scope, :-=, a, b)
3030
elseif @capture(expr, f_(args__))
3131
return callExpr(scope, f, args)
32+
elseif @capture(expr, a_::b_)
33+
return declExpr(scope, a, b)
3234
elseif @capture(expr, a_[b_])
3335
return indexExpr(scope, a, b)
3436
elseif @capture(expr, a_.b_)
3537
return accessExpr(scope, a, b)
38+
elseif @capture(expr, a_{b__})
39+
return typeExpr(scope, a, b)
3640
elseif @capture(expr, for idx_ in range_ block__ end)
3741
return rangeBlock(scope, idx, range, block)
3842
elseif @capture(expr, if cond_ block__ end)
@@ -52,15 +56,37 @@ function inferExpr(scope::Scope, expr::Expr)
5256
end
5357
end
5458

59+
function declExpr(scope, a::Val{:hello})
60+
@error "Not implemented yet"
61+
end
62+
63+
function declExpr(scope, a::Symbol, b::Symbol)
64+
aExpr = inferExpr(scope, a)
65+
inferScope!(scope, aExpr)
66+
bExpr = eval(b)
67+
return DeclExpr(aExpr, bExpr)
68+
end
69+
70+
function typeExpr(scope, a::Symbol, b::Vector{Any})
71+
aExpr = inferExpr(scope, a)
72+
bExpr = map(x -> inferExpr(scope, x), b)
73+
return TypeExpr(aExpr, bExpr)
74+
end
75+
76+
function declExpr(scope, a::Symbol, b::Expr)
77+
bExpr = inferExpr(scope, b)
78+
aExpr = inferExpr(scope, a)
79+
return DeclExpr(aExpr, bExpr)
80+
end
81+
82+
5583
function inferVariable(scope, expr::Expr)
5684
if @capture(expr, a_::b_{t__})
5785
push!(scope.locals, a)
5886
return WGPUVariable(a, eval(b), Generic, nothing, ) # TODO t is ignored
5987
elseif @capture(expr, a_::b_)
6088
push!(scope.locals, a)
6189
return WGPUVariable(a, eval(b), Generic, nothing, )
62-
elseif @capture(expr, a_[b_])
63-
return indexExpr(scope, a, b)
6490
else
6591
error("This expression $expr type is not captured yet")
6692
end

src/codegen/test.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ end
8989
a = WgpuArray(rand(Float32, 4, 4));
9090
b = WgpuArray(rand(Float32, 4, 4));
9191

92-
scope = Scope([], [], 0, nothing, quote end)
92+
scope = Scope([:out, :x], [], 0, nothing, quote end)
9393
inferredExpr = inferExpr(
9494
scope,
9595
:(@wgpukernel launch=true workgroupSize=(4, 4) workgroupCount=(1, 1) $cast_kernel($a, $b))

src/codegen/transpile.jl

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@ export transpile
22

33
transpile(scope::Scope, s::Scalar) = s.element
44
transpile(scope::Scope, var::WGPUVariable) = var.dataType == Any ? :($(var.sym)) : :($(var.sym)::$(var.dataType))
5-
transpile(scope::Scope, lhs::LHS) = transpile(scope, lhs.variable, Val(lhs.mutable))
6-
transpile(scope::Scope, var::WGPUVariable, ::Val{false}) = :(@var $(transpile(scope, var)))
7-
transpile(scope::Scope, var::WGPUVariable, ::Val{true}) = :($(var.sym))
5+
transpile(scope::Scope, lhs::LHS) = transpile(scope, lhs.variable)
6+
transpile(scope::Scope, var::WGPUVariable, ::Val{true}) = :(@var $(transpile(scope, var)))
7+
transpile(scope::Scope, var::WGPUVariable, ::Val{false}) = :($(var.sym))
88
transpile(scope::Scope, rhs::RHS) = transpile(scope, rhs.rhsExpr)
99
transpile(scope::Scope, binOp::BinaryOp) = transpile(scope, binOp, Val(binOp.op))
1010

@@ -46,6 +46,14 @@ function transpile(scope::Scope, acsExpr::AccessExpr)
4646
return Expr(:., transpile(scope, acsExpr.sym), QuoteNode(transpile(scope, acsExpr.field)))
4747
end
4848

49+
transpile(scope::Scope, declExpr::DeclExpr) = Expr(:(::), map(x -> transpile(scope, x), (declExpr.sym, declExpr.dataType))...)
50+
transpile(scope::Scope, ::Type{T}) where T = :($T)
51+
52+
transpile(scope::Scope, typeExpr::TypeExpr) = Expr(
53+
:curly, transpile(scope, typeExpr.sym),
54+
map(x -> transpile(scope, x), typeExpr.types)...
55+
)
56+
4957
function transpile(scope::Scope, rblock::RangeBlock)
5058
(start, step, stop) = map(x -> transpile(scope, x), (rblock.start, rblock.step, rblock.stop))
5159
range = :($start:$step:$stop)

0 commit comments

Comments
 (0)