Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 32 additions & 24 deletions src/onepass.jl
Original file line number Diff line number Diff line change
Expand Up @@ -698,9 +698,9 @@ function p_constraint_exa!(p, p_ocp, e1, e2, e3, c_type, label)
e2 = replace_call(e2, p.x, p.t0, x0)
e2 = replace_call(e2, p.x, p.tf, xf)
e2 = subs2(e2, x0, p.x, 0)
e2 = subs(e2, x0, :([$(p.x)[$k, 0] for $k 1:$(p.dim_x)]))
e2 = subs(e2, x0, :([$(p.x)[$k, 0] for $k in 1:($(p.dim_x))]))
e2 = subs2(e2, xf, p.x, :grid_size)
e2 = subs(e2, xf, :([$(p.x)[$k, grid_size] for $k 1:$(p.dim_x)]))
e2 = subs(e2, xf, :([$(p.x)[$k, grid_size] for $k in 1:($(p.dim_x))]))
concat(code, :($pref.constraint($p_ocp, $e2; lcon=($e1[1]), ucon=($e3[1])))) # todo: e1/3[1] will be e1/3[k] when vectorised over dim
end
(:initial, rg) => begin
Expand Down Expand Up @@ -801,9 +801,9 @@ function p_constraint_exa!(p, p_ocp, e1, e2, e3, c_type, label)
j = __symgen(:j)
k = __symgen(:k)
e2 = subs2(e2, xt, p.x, j)
e2 = subs(e2, xt, :([$(p.x)[$k, $j] for $k 1:$(p.dim_x)]))
e2 = subs(e2, xt, :([$(p.x)[$k, $j] for $k in 1:($(p.dim_x))]))
e2 = subs2(e2, ut, p.u, j)
e2 = subs(e2, ut, :([$(p.u)[$k, $j] for $k 1:$(p.dim_u)]))
e2 = subs(e2, ut, :([$(p.u)[$k, $j] for $k in 1:($(p.dim_u))]))
e2 = subs(e2, p.t, :($(p.t0) + $j * $(p.dt)))
concat(
code,
Expand Down Expand Up @@ -904,31 +904,34 @@ function p_dynamics_coord_exa!(p, p_ocp, x, i, t, e)
j12 = :($j1 + 0.5)
k = __symgen(:k)
ej1 = subs2(e, xt, p.x, j1)
ej1 = subs(ej1, xt, :([$(p.x)[$k, $j1] for $k 1:$(p.dim_x)]))
ej1 = subs(ej1, xt, :([$(p.x)[$k, $j1] for $k in 1:($(p.dim_x))]))
ej1 = subs2(ej1, ut, p.u, j1)
ej1 = subs(ej1, ut, :([$(p.u)[$k, $j1] for $k 1:$(p.dim_u)]))
ej1 = subs(ej1, ut, :([$(p.u)[$k, $j1] for $k in 1:($(p.dim_u))]))
ej1 = subs(ej1, p.t, :($(p.t0) + $j1 * $(p.dt)))
ej2 = subs2(e, xt, p.x, j2)
ej2 = subs(ej2, xt, :([$(p.x)[$k, $j2] for $k 1:$(p.dim_x)]))
ej2 = subs(ej2, xt, :([$(p.x)[$k, $j2] for $k in 1:($(p.dim_x))]))
ej2 = subs2(ej2, ut, p.u, j2)
ej2 = subs(ej2, ut, :([$(p.u)[$k, $j2] for $k 1:$(p.dim_u)]))
ej2 = subs(ej2, ut, :([$(p.u)[$k, $j2] for $k in 1:($(p.dim_u))]))
ej2 = subs(ej2, p.t, :($(p.t0) + $j2 * $(p.dt)))
ej12 = subs2m(e, xt, p.x, j1)
ej12 = subs(ej12, xt, :([(($(p.x)[$k, $j1] + $(p.x)[$k, $j1 + 1]) / 2) for $k ∈ 1:$(p.dim_x)]))
ej12 = subs(
ej12, xt, :([(($(p.x)[$k, $j1] + $(p.x)[$k, $j1 + 1]) / 2) for $k in 1:($(p.dim_x))
])
)
ej12 = subs2(ej12, ut, p.u, j1)
ej12 = subs(ej12, ut, :([$(p.u)[$k, $j1] for $k 1:$(p.dim_u)]))
ej12 = subs(ej12, ut, :([$(p.u)[$k, $j1] for $k in 1:($(p.dim_u))]))
ej12 = subs(ej12, p.t, :($(p.t0) + $j12 * $(p.dt)))
dxij = :($(p.x)[$i, $j2] - $(p.x)[$i, $j1])
code = quote
if scheme == :euler
$pref.constraint($p_ocp, $dxij - $(p.dt) * $ej1 for $j1 in 0:grid_size-1)
$pref.constraint($p_ocp, $dxij - $(p.dt) * $ej1 for $j1 in 0:(grid_size - 1))
elseif scheme ∈ (:euler_implicit, :euler_b) # euler_b is deprecated
$pref.constraint($p_ocp, $dxij - $(p.dt) * $ej2 for $j1 in 0:grid_size-1)
$pref.constraint($p_ocp, $dxij - $(p.dt) * $ej2 for $j1 in 0:(grid_size - 1))
elseif scheme == :midpoint
$pref.constraint($p_ocp, $dxij - $(p.dt) * $ej12 for $j1 in 0:grid_size-1)
$pref.constraint($p_ocp, $dxij - $(p.dt) * $ej12 for $j1 in 0:(grid_size - 1))
elseif scheme ∈ (:trapeze, :trapezoidal) # trapezoidal is deprecated
$pref.constraint(
$p_ocp, $dxij - $(p.dt) * ($ej1 + $ej2) / 2 for $j1 in 0:grid_size-1
$p_ocp, $dxij - $(p.dt) * ($ej1 + $ej2) / 2 for $j1 in 0:(grid_size - 1)
)
else
throw(
Expand Down Expand Up @@ -981,25 +984,28 @@ function p_lagrange_exa!(p, p_ocp, e, type)
j12 = :($j1 + 0.5)
k = __symgen(:k)
ej1 = subs2(e, xt, p.x, j1)
ej1 = subs(ej1, xt, :([$(p.x)[$k, $j1] for $k 1:$(p.dim_x)]))
ej1 = subs(ej1, xt, :([$(p.x)[$k, $j1] for $k in 1:($(p.dim_x))]))
ej1 = subs2(ej1, ut, p.u, j1)
ej1 = subs(ej1, ut, :([$(p.u)[$k, $j1] for $k 1:$(p.dim_u)]))
ej1 = subs(ej1, ut, :([$(p.u)[$k, $j1] for $k in 1:($(p.dim_u))]))
ej1 = subs(ej1, p.t, :($(p.t0) + $j1 * $(p.dt)))
ej12 = subs2m(e, xt, p.x, j1)
ej12 = subs(ej12, xt, :([(($(p.x)[$k, $j1] + $(p.x)[$k, $j1 + 1]) / 2) for $k ∈ 1:$(p.dim_x)]))
ej12 = subs(
ej12, xt, :([(($(p.x)[$k, $j1] + $(p.x)[$k, $j1 + 1]) / 2) for $k in 1:($(p.dim_x))
])
)
ej12 = subs2(ej12, ut, p.u, j1)
ej12 = subs(ej12, ut, :([$(p.u)[$k, $j1] for $k 1:$(p.dim_u)]))
ej12 = subs(ej12, ut, :([$(p.u)[$k, $j1] for $k in 1:($(p.dim_u))]))
ej12 = subs(ej12, p.t, :($(p.t0) + $j12 * $(p.dt)))
code = quote
if scheme == :euler
$pref.objective($p_ocp, $(p.dt) * $ej1 for $j1 in 0:grid_size-1)
$pref.objective($p_ocp, $(p.dt) * $ej1 for $j1 in 0:(grid_size - 1))
elseif scheme ∈ (:euler_implicit, :euler_b) # euler_b is deprecated
$pref.objective($p_ocp, $(p.dt) * $ej1 for $j1 in 1:grid_size)
elseif scheme == :midpoint
$pref.objective($p_ocp, $(p.dt) * $ej12 for $j1 in 0:grid_size-1)
$pref.objective($p_ocp, $(p.dt) * $ej12 for $j1 in 0:(grid_size - 1))
elseif scheme ∈ (:trapeze, :trapezoidal) # trapezoidal is deprecated
$pref.objective($p_ocp, $(p.dt) * $ej1 / 2 for $j1 in (0, grid_size))
$pref.objective($p_ocp, $(p.dt) * $ej1 for $j1 in 1:grid_size-1)
$pref.objective($p_ocp, $(p.dt) * $ej1 for $j1 in 1:(grid_size - 1))
else
throw(
"unknown numerical scheme: $scheme (possible choices are :euler, :euler_implicit, :midpoint, :trapeze)",
Expand Down Expand Up @@ -1049,9 +1055,9 @@ function p_mayer_exa!(p, p_ocp, e, type)
e = replace_call(e, p.x, p.t0, x0)
e = replace_call(e, p.x, p.tf, xf)
e = subs2(e, x0, p.x, 0)
e = subs(e, x0, :([$(p.x)[$k, 0] for $k 1:$(p.dim_x)]))
e = subs(e, x0, :([$(p.x)[$k, 0] for $k in 1:($(p.dim_x))]))
e = subs2(e, xf, p.x, :grid_size)
e = subs(e, xf, :([$(p.x)[$k, grid_size] for $k 1:$(p.dim_x)]))
e = subs(e, xf, :([$(p.x)[$k, grid_size] for $k in 1:($(p.dim_x))]))
# now, x[i](t0) has been replaced by x[i, 0] and x[i](tf) by x[i, grid_size]
code = :($pref.objective($p_ocp, $e))
return __wrap(code, p.lnum, p.line)
Expand Down Expand Up @@ -1369,7 +1375,9 @@ function def_exa(e; log=false)
$(p.box_u) # lvar and uvar for control
$(p.box_v) # lvar and uvar for variable (after x and u for compatibility with CTDirect)
$p_ocp = $pref.ExaCore(
base_type; backend=backend, minimize=($p.criterion == :min) # not $(p.xxxx) as this info is known statically
base_type;
backend=backend,
minimize=($p.criterion == :min), # not $(p.xxxx) as this info is known statically
)
$code
$dyn_check
Expand Down
29 changes: 17 additions & 12 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -122,16 +122,19 @@ julia> subs2(subs2(e, :x0, :x, 0), :xf, :x, :N)
:(x0 * (2 * x[3, N]))
```
"""
function subs2(e, x, y, j; k = __symgen(:k))
foo(x, y, j) = (h, args...) -> begin
f = Expr(h, args...)
@match f begin
:($xx[$rg]) && if ((xx == x) && is_range(rg)) end => :([$y[$k, $j] for $k ∈ $rg])
:($xx[$i]) && if (xx == x) end => :($y[$i, $j])
_ => f
function subs2(e, x, y, j; k=__symgen(:k))
foo(x, y, j) =
(h, args...) -> begin
f = Expr(h, args...)
@match f begin
:($xx[$rg]) && if ((xx == x) && is_range(rg))
end => :([$y[$k, $j] for $k in $rg])
:($xx[$i]) && if (xx == x)
end => :($y[$i, $j])
_ => f
end
end
end
expr_it(e, foo(x, y, j), x -> x)
expr_it(e, foo(x, y, j), x -> x)
end

"""
Expand Down Expand Up @@ -195,13 +198,15 @@ julia> subs2m(e, :x0, :x, 0)
:([((x[k, 0] + x[k, 0 + 1]) / 2) for k ∈ 1:3])
```
"""
function subs2m(e, x, y, j; k = __symgen(:k))
function subs2m(e, x, y, j; k=__symgen(:k))
foo(x, y, j) =
(h, args...) -> begin
f = Expr(h, args...)
@match f begin
:($xx[$rg]) && if ((xx == x) && is_range(rg)) end => :([($y[$k, $j] + $y[$k, $j + 1]) / 2 for $k ∈ $rg])
:($xx[$i]) && if (xx == x) end => :(($y[$i, $j] + $y[$i, $j + 1]) / 2)
:($xx[$rg]) && if ((xx == x) && is_range(rg))
end => :([($y[$k, $j] + $y[$k, $j + 1]) / 2 for $k in $rg])
:($xx[$i]) && if (xx == x)
end => :(($y[$i, $j] + $y[$i, $j + 1]) / 2)
_ => f
end
end
Expand Down
Loading