diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 0acc147..9a88f0c 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -11,17 +11,12 @@ jobs: # Job for the GitHub hosted runners (ubuntu, windows, macos) test-github: uses: control-toolbox/CTActions/.github/workflows/ci.yml@main - with: - versions: '["1.10", "1.11", "1.12"]' - runs_on: '["ubuntu-latest", "windows-latest"]' - archs: '["x64"]' - runner_type: 'github' # Job for the self-hosted runner moonshot (GPU/CUDA) test-moonshot: uses: control-toolbox/CTActions/.github/workflows/ci.yml@main with: versions: '["1"]' - runs_on: '[["self-hosted", "Linux", "gpu", "cuda", "cuda12"]]' + runs_on: '["moonshot"]' archs: '["x64"]' runner_type: 'self-hosted' diff --git a/.gitignore b/.gitignore index 3ca71c5..d306fa6 100644 --- a/.gitignore +++ b/.gitignore @@ -28,3 +28,7 @@ Manifest.toml # Local reports (analysis, status reports, previews) should not be tracked reports/ + +# claude +CLAUDE.local.md +.claude/ diff --git a/Project.toml b/Project.toml index 3ca7759..352f416 100644 --- a/Project.toml +++ b/Project.toml @@ -1,11 +1,12 @@ name = "CTParser" uuid = "32681960-a1b1-40db-9bff-a1ca817385d1" -version = "0.8.0" +version = "0.8.1" authors = ["Jean-Baptiste Caillau "] [deps] CTBase = "54762871-cc72-4466-b8e8-f6c8b58076cd" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" +ExaModels = "1037b233-b668-4ce9-9b63-f9f681f55dd2" MLStyle = "d8e11817-5142-5d16-987a-aa16d5891078" OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a" @@ -14,6 +15,7 @@ Unicode = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" [compat] CTBase = "0.17" DocStringExtensions = "0.9" +ExaModels = "0.9" MLStyle = "0.4" OrderedCollections = "1" Parameters = "0.12" diff --git a/src/CTParser.jl b/src/CTParser.jl index 8cd8c95..b6d54a7 100644 --- a/src/CTParser.jl +++ b/src/CTParser.jl @@ -24,5 +24,6 @@ include("defaults.jl") include("utils.jl") include("onepass.jl") include("initial_guess.jl") +#include("exa_linalg.jl") debug end diff --git a/src/exa_linalg.jl b/src/exa_linalg.jl new file mode 100644 index 0000000..152c985 --- /dev/null +++ b/src/exa_linalg.jl @@ -0,0 +1,725 @@ +# ============================================================================= +# exa_linalg.jl - Symbolic wrappers for ExaModels.AbstractNode +# +# This file defines wrapper types that completely eliminate type piracy: +# - SymNumber: wraps scalar AbstractNode values +# - SymVector, SymMatrix: wrap arrays of AbstractNode +# +# Key design principle: NO type piracy on AbstractNode +# All operations dispatch on our wrapper types, never on AbstractNode directly +# ============================================================================= + +using ExaModels: AbstractNode, Null +using LinearAlgebra +using LinearAlgebra: norm_sqr +using Base.Broadcast: Broadcasted, DefaultArrayStyle +using SparseArrays: AbstractSparseVector, AbstractSparseMatrixCSC, AbstractCompressedVector + +# ============================================================================= +# EXPORTS +# ============================================================================= +export sym_add, sym_mul, norm_sqr +export SymVector, SymMatrix, SymArray, AbstractSymArray +export SymNumber, unwrap_scalar + +# ============================================================================= +# 1. TYPE DEFINITIONS +# ============================================================================= + +# Scalar wrapper - wraps a single AbstractNode value +struct SymNumber{T<:AbstractNode} + value::T +end + +# Abstract supertype for all symbolic array wrappers +abstract type AbstractSymArray{N} <: AbstractArray{SymNumber, N} end + +# Symbolic vector wrapper (now stores SymNumber elements) +struct SymVector <: AbstractSymArray{1} + data::Vector{SymNumber} +end + +# Symbolic matrix wrapper (now stores SymNumber elements) +struct SymMatrix <: AbstractSymArray{2} + data::Matrix{SymNumber} +end + +# Generic N-dimensional wrapper (for future extensibility) +struct SymArray{N} <: AbstractSymArray{N} + data::Array{SymNumber, N} +end + +# ============================================================================= +# 2. CONSTRUCTORS +# ============================================================================= + +# Convenience constructors +# Note: The struct definition automatically provides: +# SymNumber{T}(value::T) where T<:AbstractNode (inner constructor) +# SymNumber(x::AbstractNode) (outer constructor) +# We only need to add constructors for non-AbstractNode types: +SymNumber(x::Number) = SymNumber{Null{typeof(x)}}(Null(x)) + +# Construct vectors from AbstractNode arrays +function SymVector(v::AbstractVector{<:AbstractNode}) + return SymVector([SymNumber(x) for x in v]) +end + +# Construct matrices from AbstractNode arrays +function SymMatrix(m::AbstractMatrix{<:AbstractNode}) + return SymMatrix([SymNumber(x) for x in m]) +end + +# Construct from SymNumber arrays (already wrapped) +SymVector(v::AbstractVector{SymNumber}) = SymVector(collect(v)) +SymMatrix(m::AbstractMatrix{SymNumber}) = SymMatrix(collect(m)) + +# Convenience constructors with UndefInitializer +SymVector(::UndefInitializer, n::Int) = SymVector(Vector{SymNumber}(undef, n)) +SymMatrix(::UndefInitializer, m::Int, n::Int) = SymMatrix(Matrix{SymNumber}(undef, m, n)) +SymArray{N}(::UndefInitializer, dims::Vararg{Int,N}) where {N} = SymArray{N}(Array{SymNumber,N}(undef, dims...)) + +# ============================================================================= +# 3. UNWRAP HELPERS +# ============================================================================= + +# Extract underlying AbstractNode from SymNumber +unwrap_scalar(x::SymNumber) = x.value +unwrap_scalar(x::AbstractNode) = x # Pass-through if already unwrapped +unwrap_scalar(x::Number) = Null(x) # Convert numbers to Null + +# Extract underlying data from array wrappers (returns array of AbstractNode) +unwrap(v::SymVector) = [unwrap_scalar(x) for x in v.data] +unwrap(m::SymMatrix) = [unwrap_scalar(x) for x in m.data] +unwrap(a::SymArray) = [unwrap_scalar(x) for x in a.data] +unwrap(x::SymNumber) = unwrap_scalar(x) +unwrap(x::AbstractNode) = x # Pass-through +unwrap(x::Number) = x # Pass-through +unwrap(x::AbstractArray) = x # Pass-through + +# ============================================================================= +# 4. SYMNUMBER OPERATIONS (NO TYPE PIRACY - dispatching on OUR type) +# ============================================================================= + +# Arithmetic operations +Base.:+(x::SymNumber, y::SymNumber) = SymNumber(unwrap_scalar(x) + unwrap_scalar(y)) +Base.:-(x::SymNumber, y::SymNumber) = SymNumber(unwrap_scalar(x) - unwrap_scalar(y)) +Base.:*(x::SymNumber, y::SymNumber) = SymNumber(unwrap_scalar(x) * unwrap_scalar(y)) +Base.:/(x::SymNumber, y::SymNumber) = SymNumber(unwrap_scalar(x) / unwrap_scalar(y)) +Base.:^(x::SymNumber, p::Real) = SymNumber(unwrap_scalar(x)^p) +Base.:^(x::SymNumber, y::SymNumber) = SymNumber(unwrap_scalar(x)^unwrap_scalar(y)) + +Base.:-(x::SymNumber) = SymNumber(-unwrap_scalar(x)) +Base.:+(x::SymNumber) = x + +# Mixed operations with numbers +Base.:+(x::SymNumber, y::Number) = SymNumber(unwrap_scalar(x) + y) +Base.:+(x::Number, y::SymNumber) = SymNumber(x + unwrap_scalar(y)) +Base.:-(x::SymNumber, y::Number) = SymNumber(unwrap_scalar(x) - y) +Base.:-(x::Number, y::SymNumber) = SymNumber(x - unwrap_scalar(y)) +Base.:*(x::SymNumber, y::Number) = SymNumber(unwrap_scalar(x) * y) +Base.:*(x::Number, y::SymNumber) = SymNumber(x * unwrap_scalar(y)) +Base.:/(x::SymNumber, y::Number) = SymNumber(unwrap_scalar(x) / y) +Base.:/(x::Number, y::SymNumber) = SymNumber(x / unwrap_scalar(y)) + +# Math functions +Base.abs(x::SymNumber) = SymNumber(abs(unwrap_scalar(x))) +Base.abs2(x::SymNumber) = SymNumber(abs2(unwrap_scalar(x))) +Base.sqrt(x::SymNumber) = SymNumber(sqrt(unwrap_scalar(x))) +Base.exp(x::SymNumber) = SymNumber(exp(unwrap_scalar(x))) +Base.log(x::SymNumber) = SymNumber(log(unwrap_scalar(x))) +Base.sin(x::SymNumber) = SymNumber(sin(unwrap_scalar(x))) +Base.cos(x::SymNumber) = SymNumber(cos(unwrap_scalar(x))) +Base.tan(x::SymNumber) = SymNumber(tan(unwrap_scalar(x))) + +# Identity elements +Base.zero(::Type{<:SymNumber}) = SymNumber(Null(nothing)) +Base.zero(::SymNumber) = SymNumber(Null(nothing)) +Base.one(::Type{<:SymNumber}) = SymNumber(Null(1)) +Base.one(::SymNumber) = SymNumber(Null(1)) + +# Adjoint/transpose +Base.adjoint(x::SymNumber) = x +Base.transpose(x::SymNumber) = x +Base.conj(x::SymNumber) = x + +# Scalar properties +Base.broadcastable(x::SymNumber) = Ref(x) +Base.iterate(::SymNumber) = nothing +Base.length(::SymNumber) = 1 +Base.size(::SymNumber) = () +Base.ndims(::SymNumber) = 0 +Base.ndims(::Type{<:SymNumber}) = 0 +Base.IteratorSize(::Type{<:SymNumber}) = Base.HasShape{0}() + +# Promotion and conversion +Base.promote_rule(::Type{<:SymNumber}, ::Type{<:Number}) = SymNumber +Base.promote_rule(::Type{SymNumber{S}}, ::Type{SymNumber{T}}) where {S,T} = SymNumber +Base.convert(::Type{SymNumber}, x::Number) = SymNumber(Null(x)) +Base.convert(::Type{SymNumber}, x::SymNumber) = x +Base.convert(::Type{SymNumber}, x::AbstractNode) = SymNumber(x) +Base.convert(::Type{SymNumber{T}}, x::SymNumber) where {T} = x # Allow conversion between SymNumber types +Base.convert(::Type{SymNumber{T}}, x::AbstractNode) where {T<:AbstractNode} = SymNumber(x) + +# ============================================================================= +# 5. ARRAY INTERFACE +# ============================================================================= + +# Size and shape +Base.size(v::SymVector) = size(v.data) +Base.size(m::SymMatrix) = size(m.data) +Base.size(a::SymArray) = size(a.data) + +Base.length(v::SymVector) = length(v.data) +Base.length(m::SymMatrix) = length(m.data) +Base.length(a::SymArray) = length(a.data) + +Base.axes(v::SymVector) = axes(v.data) +Base.axes(m::SymMatrix) = axes(m.data) +Base.axes(a::SymArray) = axes(a.data) + +# Index style +Base.IndexStyle(::Type{<:AbstractSymArray}) = IndexLinear() + +# Indexing - Returns SymNumber (wrapped scalar) +Base.getindex(v::SymVector, i::Int) = v.data[i] +Base.getindex(m::SymMatrix, i::Int, j::Int) = m.data[i, j] +Base.getindex(m::SymMatrix, i::Int) = m.data[i] # Linear indexing +Base.getindex(a::SymArray, i::Int) = a.data[i] +Base.getindex(a::SymArray{N}, I::Vararg{Int,N}) where {N} = a.data[I...] + +# Setindex - accepts both SymNumber and AbstractNode +Base.setindex!(v::SymVector, val::SymNumber, i::Int) = (v.data[i] = val) +Base.setindex!(v::SymVector, val::AbstractNode, i::Int) = (v.data[i] = SymNumber(val)) +Base.setindex!(m::SymMatrix, val::SymNumber, i::Int, j::Int) = (m.data[i, j] = val) +Base.setindex!(m::SymMatrix, val::AbstractNode, i::Int, j::Int) = (m.data[i, j] = SymNumber(val)) +Base.setindex!(m::SymMatrix, val::SymNumber, i::Int) = (m.data[i] = val) +Base.setindex!(m::SymMatrix, val::AbstractNode, i::Int) = (m.data[i] = SymNumber(val)) +Base.setindex!(a::SymArray, val::SymNumber, i::Int) = (a.data[i] = val) +Base.setindex!(a::SymArray, val::AbstractNode, i::Int) = (a.data[i] = SymNumber(val)) +Base.setindex!(a::SymArray{N}, val::SymNumber, I::Vararg{Int,N}) where {N} = (a.data[I...] = val) +Base.setindex!(a::SymArray{N}, val::AbstractNode, I::Vararg{Int,N}) where {N} = (a.data[I...] = SymNumber(val)) + +# Slicing - returns wrapped types +Base.getindex(v::SymVector, r::AbstractRange) = SymVector(v.data[r]) +Base.getindex(v::SymVector, inds::AbstractVector{Int}) = SymVector(v.data[inds]) + +Base.getindex(m::SymMatrix, ::Colon, j::Int) = SymVector(m.data[:, j]) +Base.getindex(m::SymMatrix, i::Int, ::Colon) = SymVector(m.data[i, :]) +Base.getindex(m::SymMatrix, r1::AbstractRange, r2::AbstractRange) = SymMatrix(m.data[r1, r2]) + +# Similar - returns wrapped type +Base.similar(v::SymVector) = SymVector(similar(v.data)) +Base.similar(v::SymVector, ::Type{SymNumber}) = SymVector(similar(v.data)) +Base.similar(v::SymVector, ::Type{SymNumber}, dims::Tuple{Vararg{Int}}) = + SymVector(similar(v.data, SymNumber, dims)) + +Base.similar(m::SymMatrix) = SymMatrix(similar(m.data)) +Base.similar(m::SymMatrix, ::Type{SymNumber}) = SymMatrix(similar(m.data)) +Base.similar(m::SymMatrix, ::Type{SymNumber}, dims::Tuple{Vararg{Int}}) = + SymMatrix(similar(m.data, SymNumber, dims)) + +Base.similar(a::SymArray{N}) where {N} = SymArray{N}(similar(a.data)) +Base.similar(a::SymArray{N}, ::Type{SymNumber}) where {N} = SymArray{N}(similar(a.data)) + +# ============================================================================= +# 6. CONVERSION +# ============================================================================= + +# Convert AbstractArray{<:AbstractNode} to wrapped types +Base.convert(::Type{SymVector}, v::AbstractVector{<:AbstractNode}) = SymVector(v) +Base.convert(::Type{SymMatrix}, m::AbstractMatrix{<:AbstractNode}) = SymMatrix(m) + +# Convert wrapped types back to regular arrays +Base.convert(::Type{Vector{<:AbstractNode}}, v::SymVector) = unwrap(v) +Base.convert(::Type{Matrix{<:AbstractNode}}, m::SymMatrix) = unwrap(m) + +# ============================================================================= +# 7. SYMBOLIC ARITHMETIC HELPERS +# ============================================================================= + +# Symbolic addition with Null(nothing) as additive identity +function sym_add(a::SymNumber, b::SymNumber) + av = unwrap_scalar(a) + bv = unwrap_scalar(b) + av isa Null{Nothing} && return b + bv isa Null{Nothing} && return a + return SymNumber(av + bv) +end + +sym_add(a::AbstractNode, b::AbstractNode) = sym_add(SymNumber(a), SymNumber(b)) +sym_add(a::SymNumber, b::AbstractNode) = sym_add(a, SymNumber(b)) +sym_add(a::AbstractNode, b::SymNumber) = sym_add(SymNumber(a), b) + +# Symbolic multiplication +sym_mul(a::SymNumber, b::SymNumber) = a * b +sym_mul(a::AbstractNode, b::AbstractNode) = SymNumber(a) * SymNumber(b) +sym_mul(a::SymNumber, b::AbstractNode) = a * SymNumber(b) +sym_mul(a::AbstractNode, b::SymNumber) = SymNumber(a) * b +sym_mul(a::SymNumber, b::Number) = a * b +sym_mul(a::Number, b::SymNumber) = a * b + +# ============================================================================= +# 8. MATRIX-VECTOR PRODUCTS +# ============================================================================= + +# Numeric matrix x Symbolic vector -> SymVector +function Base.:*(A::AbstractMatrix{<:Number}, x::SymVector) + m, n = size(A) + @assert n == length(x) "Dimension mismatch: matrix has $(n) columns, vector has $(length(x)) elements" + + result = SymVector(undef, m) + for i in 1:m + acc = zero(SymNumber) + for j in 1:n + acc = sym_add(acc, sym_mul(A[i, j], x[j])) + end + result[i] = acc + end + return result +end + +# Disambiguation: Diagonal x SymVector +function Base.:*(D::LinearAlgebra.Diagonal{<:Number}, x::SymVector) + n = length(D.diag) + @assert n == length(x) "Dimension mismatch: diagonal has $(n) elements, vector has $(length(x)) elements" + + result = SymVector(undef, n) + for i in 1:n + result[i] = SymNumber(D.diag[i] * unwrap_scalar(x[i])) + end + return result +end + +# Symbolic matrix x Numeric vector -> SymVector +function Base.:*(A::SymMatrix, x::AbstractVector{<:Number}) + m, n = size(A) + @assert n == length(x) "Dimension mismatch: matrix has $(n) columns, vector has $(length(x)) elements" + + result = SymVector(undef, m) + for i in 1:m + acc = zero(SymNumber) + for j in 1:n + acc = sym_add(acc, sym_mul(A[i, j], x[j])) + end + result[i] = acc + end + return result +end + +# Symbolic matrix x Symbolic vector -> SymVector +function Base.:*(A::SymMatrix, x::SymVector) + m, n = size(A) + @assert n == length(x) "Dimension mismatch: matrix has $(n) columns, vector has $(length(x)) elements" + + result = SymVector(undef, m) + for i in 1:m + acc = zero(SymNumber) + for j in 1:n + acc = sym_add(acc, sym_mul(A[i, j], x[j])) + end + result[i] = acc + end + return result +end + +# ============================================================================= +# 9. ROW VECTOR x MATRIX (via Adjoint) +# ============================================================================= + +# Symbolic row x Numeric matrix -> Adjoint{SymVector} +function Base.:*(x::LinearAlgebra.Adjoint{SymNumber, <:SymVector}, A::AbstractMatrix{<:Number}) + xp = parent(x) + n, p = size(A) + @assert length(xp) == n "Dimension mismatch: vector has $(length(xp)) elements, matrix has $(n) rows" + + result = SymVector(undef, p) + for j in 1:p + acc = zero(SymNumber) + for i in 1:n + acc = sym_add(acc, sym_mul(xp[i], A[i, j])) + end + result[j] = acc + end + return adjoint(result) +end + +# Numeric row x Symbolic matrix -> Adjoint{SymVector} +# Use DenseVector to avoid ambiguity with other Adjoint methods +function Base.:*(x::LinearAlgebra.Adjoint{<:Number, <:DenseVector{<:Number}}, A::SymMatrix) + xp = parent(x) + n, p = size(A) + @assert length(xp) == n "Dimension mismatch: vector has $(length(xp)) elements, matrix has $(n) rows" + + result = SymVector(undef, p) + for j in 1:p + acc = zero(SymNumber) + for i in 1:n + acc = sym_add(acc, sym_mul(xp[i], A[i, j])) + end + result[j] = acc + end + return adjoint(result) +end + +# Disambiguation: Generic Adjoint x SymMatrix +function Base.:*(x::LinearAlgebra.Adjoint{<:Number, <:AbstractVector}, A::SymMatrix) + # Delegate to the DenseVector version by converting parent to array + xp_dense = collect(parent(x)) + return adjoint(xp_dense) * A +end + +# Symbolic row x Symbolic matrix -> Adjoint{SymVector} +function Base.:*(x::LinearAlgebra.Adjoint{SymNumber, <:SymVector}, A::SymMatrix) + xp = parent(x) + n, p = size(A) + @assert length(xp) == n "Dimension mismatch: vector has $(length(xp)) elements, matrix has $(n) rows" + + result = SymVector(undef, p) + for j in 1:p + acc = zero(SymNumber) + for i in 1:n + acc = sym_add(acc, sym_mul(xp[i], A[i, j])) + end + result[j] = acc + end + return adjoint(result) +end + +# ============================================================================= +# 10. MATRIX x MATRIX PRODUCTS +# ============================================================================= + +# Numeric matrix x Symbolic matrix -> SymMatrix +function Base.:*(A::AbstractMatrix{<:Number}, B::SymMatrix) + m, n = size(A) + n2, p = size(B) + @assert n == n2 "Dimension mismatch: first matrix has $(n) columns, second has $(n2) rows" + + result = SymMatrix(undef, m, p) + for i in 1:m + for j in 1:p + acc = zero(SymNumber) + for k in 1:n + acc = sym_add(acc, sym_mul(A[i, k], B[k, j])) + end + result[i, j] = acc + end + end + return result +end + +# Symbolic matrix x Numeric matrix -> SymMatrix +function Base.:*(A::SymMatrix, B::AbstractMatrix{<:Number}) + m, n = size(A) + n2, p = size(B) + @assert n == n2 "Dimension mismatch: first matrix has $(n) columns, second has $(n2) rows" + + result = SymMatrix(undef, m, p) + for i in 1:m + for j in 1:p + acc = zero(SymNumber) + for k in 1:n + acc = sym_add(acc, sym_mul(A[i, k], B[k, j])) + end + result[i, j] = acc + end + end + return result +end + +# Symbolic matrix x Symbolic matrix -> SymMatrix +function Base.:*(A::SymMatrix, B::SymMatrix) + m, n = size(A) + n2, p = size(B) + @assert n == n2 "Dimension mismatch: first matrix has $(n) columns, second has $(n2) rows" + + result = SymMatrix(undef, m, p) + for i in 1:m + for j in 1:p + acc = zero(SymNumber) + for k in 1:n + acc = sym_add(acc, sym_mul(A[i, k], B[k, j])) + end + result[i, j] = acc + end + end + return result +end + +# ============================================================================= +# 11. DOT PRODUCTS (return SymNumber scalar) +# ============================================================================= + +# Internal implementation of dot product +function _sym_dot(x, y, nx::Int, ny::Int) + @assert nx == ny "Dimension mismatch: vectors have lengths $(nx) and $(ny)" + acc = zero(SymNumber) + for i in 1:nx + acc = sym_add(acc, sym_mul(x[i], y[i])) + end + return acc +end + +# Symbolic dot Symbolic -> SymNumber +function LinearAlgebra.dot(x::SymVector, y::SymVector) + return _sym_dot(x, y, length(x), length(y)) +end + +# Numeric dot Symbolic -> SymNumber +# Use DenseVector to be more specific and avoid ambiguity with SparseArrays +function LinearAlgebra.dot(x::DenseVector{<:Number}, y::SymVector) + return _sym_dot(x, y, length(x), length(y)) +end + +# Symbolic dot Numeric -> SymNumber +# Use DenseVector to be more specific and avoid ambiguity with SparseArrays +function LinearAlgebra.dot(x::SymVector, y::DenseVector{<:Number}) + return _sym_dot(x, y, length(x), length(y)) +end + +# ============================================================================= +# 11a. SPARSE VECTOR DISAMBIGUATION +# ============================================================================= +# These methods resolve ambiguities with SparseArrays.dot methods. +# SparseArrays defines: +# dot(::AbstractCompressedVector, ::AbstractVector) +# dot(::AbstractVector, ::AbstractCompressedVector) +# We need explicit methods for sparse vectors with SymVector. + +# Type alias for SparseArrays' compressed vector types +const SparseVecLike = Union{ + AbstractCompressedVector, + SubArray{<:Any, 1, <:AbstractSparseMatrixCSC, Tuple{Base.Slice{Base.OneTo{Int}}, Int}, false}, + SubArray{<:Any, 1, <:AbstractSparseVector, Tuple{Base.Slice{Base.OneTo{Int}}}, false} +} + +# Sparse vector dot SymVector -> SymNumber (disambiguation) +function LinearAlgebra.dot(x::SparseVecLike, y::SymVector) + return _sym_dot(x, y, length(x), length(y)) +end + +# SymVector dot Sparse vector -> SymNumber (disambiguation) +function LinearAlgebra.dot(x::SymVector, y::SparseVecLike) + return _sym_dot(x, y, length(x), length(y)) +end + +# ============================================================================= +# 11b. ADDITIONAL DISAMBIGUATION FOR ADJOINT/TRANSPOSE +# ============================================================================= + +# Disambiguate: Adjoint{SymNumber, SymVector} * Adjoint{Number, Transpose{T, Vector}} +# This handles x' * (v')' cases +function Base.:*(x::LinearAlgebra.Adjoint{SymNumber, <:SymVector}, + y::LinearAlgebra.Adjoint{<:Number, <:LinearAlgebra.Transpose{<:Any, <:AbstractVector}}) + # x' * (v')' = x' * v = dot(x, v) + # Unwrap the double transpose and compute dot product + return dot(parent(x), parent(parent(y))) +end + +# Disambiguate: (Adjoint or Transpose of numeric vector) * SymVector +# This handles row vector * column vector -> scalar +function Base.:*(x::Union{LinearAlgebra.Adjoint{<:Number, <:AbstractVector}, + LinearAlgebra.Transpose{<:Number, <:AbstractVector}}, + y::SymVector) + # Row vector * column vector = dot product (scalar) + return dot(parent(x), y) +end + +# Disambiguate: Transpose{Number, Vector} * SymMatrix +# This handles row vector * matrix -> row vector +function Base.:*(x::LinearAlgebra.Transpose{<:Number, <:AbstractVector}, A::SymMatrix) + # Convert transpose to adjoint and delegate + # For numeric vectors, transpose = adjoint + return adjoint(parent(x)) * A +end + +# ============================================================================= +# 12. INNER PRODUCTS VIA ADJOINT (x' * y) +# ============================================================================= + +# These delegate to dot products, which return SymNumber + +function Base.:*(x::LinearAlgebra.Adjoint{SymNumber, <:SymVector}, y::SymVector) + return dot(parent(x), y) +end + +function Base.:*(x::LinearAlgebra.Adjoint{SymNumber, <:SymVector}, y::DenseVector{<:Number}) + return dot(parent(x), y) +end + +function Base.:*(x::LinearAlgebra.Adjoint{<:Number, <:DenseVector{<:Number}}, y::SymVector) + return dot(parent(x), y) +end + +# Disambiguation with LinearAlgebra.Transpose +function Base.:*(x::LinearAlgebra.Adjoint{SymNumber, <:SymVector}, y::LinearAlgebra.Transpose{<:Number, <:LinearAlgebra.Adjoint{<:Any, <:AbstractVector}}) + # This is an edge case: SymVector' * (v')' + # Just delegate to the unwrapped version + return x * parent(parent(y)) +end + +# ============================================================================= +# 13. MATRIX ADJOINT / TRANSPOSE +# ============================================================================= + +# Vector adjoint returns Adjoint wrapper (for x' * A and x' * y syntax) +Base.adjoint(v::SymVector) = LinearAlgebra.Adjoint(v) +Base.transpose(v::SymVector) = LinearAlgebra.Transpose(v) + +# Matrix adjoint/transpose return wrapped SymMatrix +function Base.adjoint(A::SymMatrix) + # For symbolic (assumed real), adjoint = transpose + return SymMatrix(permutedims(A.data)) +end + +function Base.transpose(A::SymMatrix) + return SymMatrix(permutedims(A.data)) +end + +# Adjoint of adjoint returns the original +Base.adjoint(x::LinearAlgebra.Adjoint{SymNumber, <:SymVector}) = parent(x) +Base.transpose(x::LinearAlgebra.Transpose{SymNumber, <:SymVector}) = parent(x) + +# ============================================================================= +# 14. NORMS (return SymNumber scalar) +# ============================================================================= + +# norm_sqr for vectors +function LinearAlgebra.norm_sqr(x::SymVector) + return dot(x, x) +end + +# L2 norm (default) +function LinearAlgebra.norm(x::SymVector) + return sqrt(norm_sqr(x)) +end + +# Lp norm +function LinearAlgebra.norm(x::SymVector, p::Real) + if p == 2 + return norm(x) + elseif p == 1 + # L1 norm: sum of absolute values + acc = zero(SymNumber) + for i in 1:length(x) + acc = sym_add(acc, abs(x[i])) + end + return acc + elseif p == Inf + error("Infinity norm not yet implemented for symbolic vectors") + else + # General Lp norm + acc = zero(SymNumber) + for i in 1:length(x) + acc = sym_add(acc, abs(x[i])^p) + end + return acc^(1/p) + end +end + +# Frobenius norm for matrices +function LinearAlgebra.norm(A::SymMatrix) + m, n = size(A) + acc = zero(SymNumber) + for i in 1:m + for j in 1:n + acc = sym_add(acc, abs2(A[i, j])) + end + end + return sqrt(acc) +end + +# ============================================================================= +# 15. BROADCASTING +# ============================================================================= + +# Custom broadcast style for symbolic wrappers +struct SymArrayStyle{N} <: Broadcast.AbstractArrayStyle{N} end + +SymArrayStyle(::Val{N}) where N = SymArrayStyle{N}() +SymArrayStyle{M}(::Val{N}) where {M,N} = SymArrayStyle{N}() + +# Register style for wrapper types +Base.BroadcastStyle(::Type{<:SymVector}) = SymArrayStyle{1}() +Base.BroadcastStyle(::Type{<:SymMatrix}) = SymArrayStyle{2}() +Base.BroadcastStyle(::Type{<:SymArray{N}}) where {N} = SymArrayStyle{N}() + +# SymArrayStyle wins over DefaultArrayStyle +Base.BroadcastStyle(::SymArrayStyle{N}, ::DefaultArrayStyle{M}) where {N,M} = + SymArrayStyle{max(N,M)}() +Base.BroadcastStyle(::DefaultArrayStyle{M}, ::SymArrayStyle{N}) where {N,M} = + SymArrayStyle{max(N,M)}() + +# Combine SymArrayStyles +Base.BroadcastStyle(::SymArrayStyle{N}, ::SymArrayStyle{M}) where {N,M} = + SymArrayStyle{max(N,M)}() + +# Allocate output for broadcast - returns wrapped type +function Base.similar(bc::Broadcasted{SymArrayStyle{1}}, ::Type{SymNumber}) + sz = map(length, axes(bc)) + return SymVector(undef, sz[1]) +end + +function Base.similar(bc::Broadcasted{SymArrayStyle{2}}, ::Type{SymNumber}) + sz = map(length, axes(bc)) + return SymMatrix(undef, sz...) +end + +function Base.similar(bc::Broadcasted{SymArrayStyle{N}}, ::Type{SymNumber}) where {N} + sz = map(length, axes(bc)) + return SymArray{N}(undef, sz...) +end + +# Materialize broadcast +function Base.copy(bc::Broadcasted{SymArrayStyle{1}}) + result = similar(bc, SymNumber) + @inbounds for I in CartesianIndices(result.data) + val = bc[I] + result.data[I] = val isa SymNumber ? val : SymNumber(val) + end + return result +end + +function Base.copy(bc::Broadcasted{SymArrayStyle{2}}) + result = similar(bc, SymNumber) + @inbounds for I in CartesianIndices(result.data) + val = bc[I] + result.data[I] = val isa SymNumber ? val : SymNumber(val) + end + return result +end + +function Base.copy(bc::Broadcasted{SymArrayStyle{N}}) where N + result = similar(bc, SymNumber) + @inbounds for I in CartesianIndices(result.data) + val = bc[I] + result.data[I] = val isa SymNumber ? val : SymNumber(val) + end + return result +end + +# In-place broadcast +function Base.copyto!(dest::SymVector, bc::Broadcasted{SymArrayStyle{1}}) + @inbounds for I in CartesianIndices(dest.data) + val = bc[I] + dest.data[I] = val isa SymNumber ? val : SymNumber(val) + end + return dest +end + +function Base.copyto!(dest::SymMatrix, bc::Broadcasted{SymArrayStyle{2}}) + @inbounds for I in CartesianIndices(dest.data) + val = bc[I] + dest.data[I] = val isa SymNumber ? val : SymNumber(val) + end + return dest +end + +function Base.copyto!(dest::SymArray{N}, bc::Broadcasted{SymArrayStyle{N}}) where {N} + @inbounds for I in CartesianIndices(dest.data) + val = bc[I] + dest.data[I] = val isa SymNumber ? val : SymNumber(val) + end + return dest +end diff --git a/src/onepass.jl b/src/onepass.jl index 5e3aa3c..e6df46b 100644 --- a/src/onepass.jl +++ b/src/onepass.jl @@ -96,19 +96,6 @@ function e_prefix!(p) return nothing end -# Utils - -""" -$(TYPEDSIGNATURES) - -Generate a fresh symbol by concatenating the given components and a -`gensym()` suffix. - -This is used throughout the parser to create unique internal names that -do not collide with user-defined identifiers. -""" -__symgen(s...) = Symbol(s..., gensym()) - """ $(TYPEDEF) @@ -191,37 +178,14 @@ case of an exception, prints the originating line number and source text before rethrowing. """ __wrap(e, n, line) = quote - local ex try $e - catch ex + catch println("Line ", $n, ": ", $line) - throw(ex) + rethrow() end end -""" -$(TYPEDSIGNATURES) - -Return `true` if `x` represents a range. - -This predicate is specialised for `AbstractRange` values and for -expressions of the form `i:j` or `i:p:j`. -""" -is_range(x) = false -is_range(x::T) where {T<:AbstractRange} = true -is_range(x::Expr) = (x.head == :call) && (x.args[1] == :(:)) - -""" -$(TYPEDSIGNATURES) - -Return `x` itself if it is a range, or a one-element array `[x]`. - -This is a normalisation helper used when interpreting constraint -indices. -""" -as_range(x) = is_range(x) ? x : [x] - # Main code """ @@ -580,7 +544,10 @@ function p_state_exa!(p, p_ocp, x, n, xx; components_names=nothing) )) code = __wrap(code, p.lnum, p.line) dyn_con = Symbol(:dyn_con, x) # name for the constraints associated with the dynamics - code = :($x = $code; $dyn_con = Vector{$pref.Constraint}(undef, $n)) # affectation must be done outside try ... catch (otherwise declaration known only to try local scope) + code = quote + $x = $code + $dyn_con = Vector{$pref.Constraint}(undef, $n) # affectation must be done outside try ... catch (otherwise declaration known only to try local scope) + end return code end @@ -696,7 +663,7 @@ function p_constraint_fun!(p, p_ocp, e1, e2, e3, c_type, label) (:variable_range, rg) => :($pref.constraint!( $p_ocp, :variable; rg=($rg), lb=($e1), ub=($e3), label=($llabel) )) - :state_fun || control_fun || :mixed => begin # now all treated as path + :state_fun || :control_fun || :mixed => begin # now all treated as path fun = __symgen(:fun) xt = __symgen(:xt) ut = __symgen(:ut) @@ -727,17 +694,20 @@ function p_constraint_exa!(p, p_ocp, e1, e2, e3, c_type, label) code = :(length($e1) == length($e3) == 1 || throw("this constraint must be scalar")) # (vs. __throw) since raised at runtime x0 = __symgen(:x0) xf = __symgen(:xf) + k = __symgen(:k) e2 = replace_call(e2, p.x, p.t0, x0) e2 = replace_call(e2, p.x, p.tf, xf) e2 = subs2(e2, x0, p.x, 0) + e2 = subs(e2, x0, :([$(p.x)[$k, 0] for $k ∈ 1:$(p.dim_x)])) e2 = subs2(e2, xf, p.x, :grid_size) - concat(code, :($pref.constraint($p_ocp, $e2; lcon=($e1), ucon=($e3)))) + e2 = subs(e2, xf, :([$(p.x)[$k, grid_size] for $k ∈ 1:$(p.dim_x)])) + concat(code, :($pref.constraint($p_ocp, $e2; lcon=($e1[1]), ucon=($e3[1])))) # todo: e1/3[1] will be e1/3[k] when vectorised over dim end (:initial, rg) => begin if isnothing(rg) rg = :(1:($(p.dim_x))) # x(t0) implies rg == nothing but means x[1:p.dim_x](t0) e2 = subs(e2, p.x, :($(p.x)[$rg])) - elseif !is_range(rg) + else rg = as_range(rg) end code = :( @@ -756,8 +726,8 @@ function p_constraint_exa!(p, p_ocp, e1, e2, e3, c_type, label) if isnothing(rg) rg = :(1:($(p.dim_x))) e2 = subs(e2, p.x, :($(p.x)[$rg])) - elseif !is_range(rg) - rg = as_range(rg) + else + rg = as_range(rg) # case rg = i (vs i:j or i:p:j) end code = :( length($e1) == length($e3) == length($rg) || throw("wrong bound dimension") @@ -775,8 +745,8 @@ function p_constraint_exa!(p, p_ocp, e1, e2, e3, c_type, label) if isnothing(rg) rg = :(1:($(p.dim_v))) e2 = subs(e2, p.v, :($(p.v)[$rg])) - elseif !is_range(rg) - rg = as_range(rg) + else + rg = as_range(rg) # case rg = i (vs i:j or i:p:j) end code_box = :( length($e1) == length($e3) == length($rg) || throw("wrong bound dimension") @@ -791,10 +761,9 @@ function p_constraint_exa!(p, p_ocp, e1, e2, e3, c_type, label) end (:state_range, rg) => begin if isnothing(rg) - rg = :(1:($(p.dim_x))) - e2 = subs(e2, p.x, :($(p.x)[$rg])) - elseif !is_range(rg) - rg = as_range(rg) + rg = :(1:($(p.dim_x))) # NB. no need to update e2 (unused) here + else + rg = as_range(rg) # case rg = i (vs i:j or i:p:j) end code_box = :( length($e1) == length($e3) == length($rg) || throw("wrong bound dimension") @@ -809,10 +778,9 @@ function p_constraint_exa!(p, p_ocp, e1, e2, e3, c_type, label) end (:control_range, rg) => begin if isnothing(rg) - rg = :(1:($(p.dim_u))) - e2 = subs(e2, p.u, :($(p.u)[$rg])) - elseif !is_range(rg) - rg = as_range(rg) + rg = :(1:($(p.dim_u))) # NB. no need to update e2 (unused here) + else + rg = as_range(rg) # case rg = i (vs i:j or i:p:j) end code_box = :( length($e1) == length($e3) == length($rg) || throw("wrong bound dimension") @@ -825,19 +793,22 @@ function p_constraint_exa!(p, p_ocp, e1, e2, e3, c_type, label) p.box_u = concat(p.box_u, code_box) # not __wrapped since contains definition of l_u/u_u :() end - :state_fun || control_fun || :mixed => begin + :state_fun || :control_fun || :mixed => begin code = :(length($e1) == length($e3) == 1 || throw("this constraint must be scalar")) # (vs. __throw) since raised at runtime xt = __symgen(:xt) ut = __symgen(:ut) e2 = replace_call(e2, [p.x, p.u], p.t, [xt, ut]) j = __symgen(:j) + k = __symgen(:k) e2 = subs2(e2, xt, p.x, j) + e2 = subs(e2, xt, :([$(p.x)[$k, $j] for $k ∈ 1:$(p.dim_x)])) e2 = subs2(e2, ut, p.u, j) + e2 = subs(e2, ut, :([$(p.u)[$k, $j] for $k ∈ 1:$(p.dim_u)])) e2 = subs(e2, p.t, :($(p.t0) + $j * $(p.dt))) concat( code, :($pref.constraint( - $p_ocp, $e2 for $j in 0:grid_size; lcon=($e1), ucon=($e3) + $p_ocp, $e2 for $j in 0:grid_size; lcon=($e1[1]), ucon=($e3[1]) )), ) end @@ -931,26 +902,33 @@ function p_dynamics_coord_exa!(p, p_ocp, x, i, t, e) j1 = __symgen(:j) j2 = :($j1 + 1) j12 = :($j1 + 0.5) + k = __symgen(:k) ej1 = subs2(e, xt, p.x, j1) + ej1 = subs(ej1, xt, :([$(p.x)[$k, $j1] for $k ∈ 1:$(p.dim_x)])) ej1 = subs2(ej1, ut, p.u, j1) + ej1 = subs(ej1, ut, :([$(p.u)[$k, $j1] for $k ∈ 1:$(p.dim_u)])) ej1 = subs(ej1, p.t, :($(p.t0) + $j1 * $(p.dt))) ej2 = subs2(e, xt, p.x, j2) + ej2 = subs(ej2, xt, :([$(p.x)[$k, $j2] for $k ∈ 1:$(p.dim_x)])) ej2 = subs2(ej2, ut, p.u, j2) + ej2 = subs(ej2, ut, :([$(p.u)[$k, $j2] for $k ∈ 1:$(p.dim_u)])) ej2 = subs(ej2, p.t, :($(p.t0) + $j2 * $(p.dt))) - ej12 = subs5(e, xt, p.x, j1) + ej12 = subs2m(e, xt, p.x, j1) + ej12 = subs(ej12, xt, :([(($(p.x)[$k, $j1] + $(p.x)[$k, $j1 + 1]) / 2) for $k ∈ 1:$(p.dim_x)])) ej12 = subs2(ej12, ut, p.u, j1) + ej12 = subs(ej12, ut, :([$(p.u)[$k, $j1] for $k ∈ 1:$(p.dim_u)])) ej12 = subs(ej12, p.t, :($(p.t0) + $j12 * $(p.dt))) dxij = :($(p.x)[$i, $j2] - $(p.x)[$i, $j1]) code = quote if scheme == :euler - $pref.constraint($p_ocp, $dxij - $(p.dt) * $ej1 for $j1 in 0:(grid_size - 1)) + $pref.constraint($p_ocp, $dxij - $(p.dt) * $ej1 for $j1 in 0:grid_size-1) elseif scheme ∈ (:euler_implicit, :euler_b) # euler_b is deprecated - $pref.constraint($p_ocp, $dxij - $(p.dt) * $ej2 for $j1 in 0:(grid_size - 1)) + $pref.constraint($p_ocp, $dxij - $(p.dt) * $ej2 for $j1 in 0:grid_size-1) elseif scheme == :midpoint - $pref.constraint($p_ocp, $dxij - $(p.dt) * $ej12 for $j1 in 0:(grid_size - 1)) + $pref.constraint($p_ocp, $dxij - $(p.dt) * $ej12 for $j1 in 0:grid_size-1) elseif scheme ∈ (:trapeze, :trapezoidal) # trapezoidal is deprecated $pref.constraint( - $p_ocp, $dxij - $(p.dt) * ($ej1 + $ej2) / 2 for $j1 in 0:(grid_size - 1) + $p_ocp, $dxij - $(p.dt) * ($ej1 + $ej2) / 2 for $j1 in 0:grid_size-1 ) else throw( @@ -1001,22 +979,27 @@ function p_lagrange_exa!(p, p_ocp, e, type) j1 = __symgen(:j) j2 = :($j1 + 1) j12 = :($j1 + 0.5) + k = __symgen(:k) ej1 = subs2(e, xt, p.x, j1) + ej1 = subs(ej1, xt, :([$(p.x)[$k, $j1] for $k ∈ 1:$(p.dim_x)])) ej1 = subs2(ej1, ut, p.u, j1) + ej1 = subs(ej1, ut, :([$(p.u)[$k, $j1] for $k ∈ 1:$(p.dim_u)])) ej1 = subs(ej1, p.t, :($(p.t0) + $j1 * $(p.dt))) - ej12 = subs5(e, xt, p.x, j1) + ej12 = subs2m(e, xt, p.x, j1) + ej12 = subs(ej12, xt, :([(($(p.x)[$k, $j1] + $(p.x)[$k, $j1 + 1]) / 2) for $k ∈ 1:$(p.dim_x)])) ej12 = subs2(ej12, ut, p.u, j1) + ej12 = subs(ej12, ut, :([$(p.u)[$k, $j1] for $k ∈ 1:$(p.dim_u)])) ej12 = subs(ej12, p.t, :($(p.t0) + $j12 * $(p.dt))) code = quote if scheme == :euler - $pref.objective($p_ocp, $(p.dt) * $ej1 for $j1 in 0:(grid_size - 1)) + $pref.objective($p_ocp, $(p.dt) * $ej1 for $j1 in 0:grid_size-1) elseif scheme ∈ (:euler_implicit, :euler_b) # euler_b is deprecated $pref.objective($p_ocp, $(p.dt) * $ej1 for $j1 in 1:grid_size) elseif scheme == :midpoint - $pref.objective($p_ocp, $(p.dt) * $ej12 for $j1 in 0:(grid_size - 1)) + $pref.objective($p_ocp, $(p.dt) * $ej12 for $j1 in 0:grid_size-1) elseif scheme ∈ (:trapeze, :trapezoidal) # trapezoidal is deprecated $pref.objective($p_ocp, $(p.dt) * $ej1 / 2 for $j1 in (0, grid_size)) - $pref.objective($p_ocp, $(p.dt) * $ej1 for $j1 in 1:(grid_size - 1)) + $pref.objective($p_ocp, $(p.dt) * $ej1 for $j1 in 1:grid_size-1) else throw( "unknown numerical scheme: $scheme (possible choices are :euler, :euler_implicit, :midpoint, :trapeze)", @@ -1062,10 +1045,13 @@ function p_mayer_exa!(p, p_ocp, e, type) pref = prefix_exa() x0 = __symgen(:x0) xf = __symgen(:xf) + k = __symgen(:k) e = replace_call(e, p.x, p.t0, x0) e = replace_call(e, p.x, p.tf, xf) e = subs2(e, x0, p.x, 0) + e = subs(e, x0, :([$(p.x)[$k, 0] for $k ∈ 1:$(p.dim_x)])) e = subs2(e, xf, p.x, :grid_size) + e = subs(e, xf, :([$(p.x)[$k, grid_size] for $k ∈ 1:$(p.dim_x)])) # now, x[i](t0) has been replaced by x[i, 0] and x[i](tf) by x[i, grid_size] code = :($pref.objective($p_ocp, $e)) return __wrap(code, p.lnum, p.line) @@ -1133,7 +1119,7 @@ PARSING_FUN[:lagrange] = p_lagrange_fun! PARSING_FUN[:mayer] = p_mayer_fun! PARSING_FUN[:bolza] = p_bolza_fun! -# Summary of available parsing subfunctions (:fun backend) +# Summary of available parsing subfunctions (:exa backend) const PARSING_EXA = OrderedDict{Symbol,Function}() PARSING_EXA[:pragma] = p_pragma_exa! @@ -1295,7 +1281,7 @@ function def_fun(e; log=false) $p_ocp = $pref.PreModel() $code $pref.definition!($p_ocp, $ee) - $pref.time_dependence!($p_ocp; autonomous=$p.is_autonomous) + $pref.time_dependence!($p_ocp; autonomous=$p.is_autonomous) # not $(p.xxxx) as this info is known statically end if is_active_backend(:exa) @@ -1383,7 +1369,7 @@ function def_exa(e; log=false) $(p.box_u) # lvar and uvar for control $(p.box_v) # lvar and uvar for variable (after x and u for compatibility with CTDirect) $p_ocp = $pref.ExaCore( - base_type; backend=backend, minimize=($p.criterion == :min) + base_type; backend=backend, minimize=($p.criterion == :min) # not $(p.xxxx) as this info is known statically ) $code $dyn_check diff --git a/src/utils.jl b/src/utils.jl index 4e315d6..5dee3dc 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -4,6 +4,39 @@ """ $(TYPEDSIGNATURES) +Generate a fresh symbol by concatenating the given components and a +`gensym()` suffix. + +This is used throughout the parser to create unique internal names that +do not collide with user-defined identifiers. +""" +__symgen(s...) = Symbol(s..., gensym()) + +""" +$(TYPEDSIGNATURES) + +Return `true` if `x` represents a range. + +This predicate is specialised for `AbstractRange` values and for +expressions of the form `i:j` or `i:p:j`. +""" +is_range(x) = false +is_range(x::T) where {T<:AbstractRange} = true +is_range(x::Expr) = (x.head == :call) && (x.args[1] == :(:)) + +""" +$(TYPEDSIGNATURES) + +Return `x` itself if it is a range, or a one-element array `[x]`. + +This is a normalisation helper used when interpreting constraint +indices. +""" +as_range(x) = is_range(x) ? x : :(($x):($x)) + +""" +$(TYPEDSIGNATURES) + Expr iterator: apply `_Expr` to nodes and `f` to leaves of the AST. # Example @@ -64,39 +97,47 @@ end """ $(TYPEDSIGNATURES) -Substitute x[i] by y[i, j], whatever i, in e. See also: subs5. +Substitute occurrences of symbol `x` in expression `e` with indexed access to `y` at time index `j`. +Handles two patterns: +- `x[i]` (scalar index) → `y[i, j]` +- `x[1:3]` (range index) → `[y[k, j] for k ∈ 1:3]` + +See also: subs2m. # Examples ```@example +julia> # Scalar indexing julia> e = :(x0[1] * 2xf[3] - cos(xf[2]) * 2x0[2]) -:(x0[1] * (2 * xf[3]) - cos(xf[2]) * (2 * x0[2])) - julia> subs2(subs2(e, :x0, :x, 0), :xf, :x, :N) :(x[1, 0] * (2 * x[3, N]) - cos(x[2, N]) * (2 * x[2, 0])) -julia> e = :(x0 * 2xf[3] - cos(xf) * 2x0[2]) -:(x0 * (2 * xf[3]) - cos(xf) * (2 * x0[2])) +julia> # Range indexing +julia> e = :(x0[1:3]) +julia> subs2(e, :x0, :x, 0; k = :k) +:([x[k, 0] for k ∈ 1:3]) +julia> # Bare symbols are not substituted +julia> e = :(x0 * 2xf[3]) julia> subs2(subs2(e, :x0, :x, 0), :xf, :x, :N) -:(x0 * (2 * x[3, N]) - cos(xf) * (2 * x[2, 0])) +:(x0 * (2 * x[3, N])) ``` """ -function subs2(e, x, y, j) +function subs2(e, x, y, j; k = __symgen(:k)) foo(x, y, j) = (h, args...) -> begin f = Expr(h, args...) @match f begin - :($xx[$i]) && if (xx == x) - end => :($y[$i, $j]) + :($xx[$rg]) && if ((xx == x) && is_range(rg)) end => :([$y[$k, $j] for $k ∈ $rg]) + :($xx[$i]) && if (xx == x) end => :($y[$i, $j]) _ => f end end - expr_it(e, foo(x, y, j), x -> x) + expr_it(e, foo(x, y, j), x -> x) end """ $(TYPEDSIGNATURES) -Substitute x[rg] by y[i, j], whatever rg, in e. +Substitute x[rg] by y[i, j], whatever rg, in e. (Note: rg is then expected to be used to loop on i.) # Examples ```@example @@ -125,59 +166,42 @@ end """ $(TYPEDSIGNATURES) -Substitute x[rg] by y[i], whatever rg, in e. +Substitute x[i] or x[rg] in e for the midpoint scheme: +- x[i] → (y[i, j] + y[i, j + 1]) / 2 (scalar indexing) +- x[rg] → [(y[k, j] + y[k, j + 1]) / 2 for k ∈ rg] (range indexing) -# Examples -```@example -julia> e = :(v[1:2:d] * 2xf[1:3]) -:(v[1:2:d] * (2 * xf[1:3])) +Bare symbols like x (without indexing) are NOT substituted. -julia> subs4(e, :v, :v, :i) -:(v[i] * (2 * xf[1:3])) - -julia> subs4(e, :xf, :xf, 1) -:(v[1:2:d] * (2 * xf[1])) -``` -""" -function subs4(e, x, y, i) # currently unused - foo(x, y, i) = (h, args...) -> begin - f = Expr(h, args...) - @match f begin - :($xx[$rg]) && if (xx == x) - end => :($y[$i]) - _ => f - end - end - expr_it(e, foo(x, y, i), x -> x) -end - -""" -$(TYPEDSIGNATURES) - -Substitute x[i] by (y[i, j] + y[i, j + 1]) / 2, whatever i, in e. See also: subs2. +See also: subs2. # Examples ```@example julia> e = :(x0[1] * 2xf[3] - cos(xf[2]) * 2x0[2]) :(x0[1] * (2 * xf[3]) - cos(xf[2]) * (2 * x0[2])) -julia> subs5(subs5(e, :x0, :x, 0), :xf, :x, :N) +julia> subs2m(subs2m(e, :x0, :x, 0), :xf, :x, :N) :(((x[1, 0] + x[1, 0 + 1]) / 2) * (2 * ((x[3, N] + x[3, N + 1]) / 2)) - cos((x[2, N] + x[2, N + 1]) / 2) * (2 * ((x[2, 0] + x[2, 0 + 1]) / 2))) julia> e = :(x0 * 2xf[3] - cos(xf) * 2x0[2]) :(x0 * (2 * xf[3]) - cos(xf) * (2 * x0[2])) -julia> subs5(subs5(e, :x0, :x, 0), :xf, :x, :N) +julia> subs2m(subs2m(e, :x0, :x, 0), :xf, :x, :N) :(x0 * (2 * ((x[3, N] + x[3, N + 1]) / 2)) - cos(xf) * (2 * ((x[2, 0] + x[2, 0 + 1]) / 2))) + +julia> e = :(x0[1:3]) +:(x0[1:3]) + +julia> subs2m(e, :x0, :x, 0) +:([((x[k, 0] + x[k, 0 + 1]) / 2) for k ∈ 1:3]) ``` """ -function subs5(e, x, y, j) +function subs2m(e, x, y, j; k = __symgen(:k)) foo(x, y, j) = (h, args...) -> begin f = Expr(h, args...) @match f begin - :($xx[$i]) && if (xx == x) - end => :(($y[$i, $j] + $y[$i, $j + 1]) / 2) + :($xx[$rg]) && if ((xx == x) && is_range(rg)) end => :([($y[$k, $j] + $y[$k, $j + 1]) / 2 for $k ∈ $rg]) + :($xx[$i]) && if (xx == x) end => :(($y[$i, $j] + $y[$i, $j + 1]) / 2) _ => f end end diff --git a/test/Project.toml b/test/Project.toml index 7e9387c..4eff5f7 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -3,16 +3,19 @@ Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" CTBase = "54762871-cc72-4466-b8e8-f6c8b58076cd" CTModels = "34c4fa32-2049-4079-8329-de33c2a22e2d" +CTParser = "32681960-a1b1-40db-9bff-a1ca817385d1" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" ExaModels = "1037b233-b668-4ce9-9b63-f9f681f55dd2" Interpolations = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59" -KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" MadNLP = "2621e9c9-9eb4-46b1-8089-e8c72242dfb6" MadNLPGPU = "d72a61cc-809d-412f-99be-fd81f4b8a598" NLPModels = "a4795742-8479-5a88-8948-cc11e1c8c1a6" OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +[sources] +CTParser = {path = "/data/caillau/CTParser.jl"} + [compat] Aqua = "0.8" BenchmarkTools = "1" @@ -21,7 +24,6 @@ CTModels = "0.7" CUDA = "5" ExaModels = "0.9" Interpolations = "0.16" -KernelAbstractions = "0.9" MadNLP = "0.8" MadNLPGPU = "0.7" NLPModels = "0.21" diff --git a/test/runtests.jl b/test/runtests.jl index a38ceba..463a957 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -5,9 +5,8 @@ import CTParser: CTParser, subs, subs2, + subs2m, subs3, - subs4, - subs5, replace_call, has, concat, @@ -46,13 +45,14 @@ import CTModels: criterion, Model, get_build_examodel -using ExaModels: ExaModels +using ExaModels: ExaModels, AbstractNode +using MadNLP using MadNLP -using MadNLPGPU using CUDA using BenchmarkTools using Interpolations using NLPModels +#using LinearAlgebra: LinearAlgebra, dot, norm, norm_sqr macro ignore(e) return :() @@ -74,6 +74,8 @@ function default_tests() :onepass_fun_bis => true, :onepass_exa => true, :onepass_exa_bis => true, + #debug :exa_linalg => true, + :usecase8_gpu => false, # Temporary test for GPU-only use case 8 ) end diff --git a/test/test_aqua.jl b/test/test_aqua.jl index c3fb35c..f7b8f5d 100644 --- a/test/test_aqua.jl +++ b/test/test_aqua.jl @@ -1,13 +1,8 @@ function test_aqua() @testset "Aqua.jl" begin - Aqua.test_all( - CTParser; - ambiguities=false, - #stale_deps=(ignore=[:MLStyle],), - deps_compat=(ignore=[:LinearAlgebra, :Unicode],), - piracies=true, - ) - # do not warn about ambiguities in dependencies - Aqua.test_ambiguities(CTParser) + # Full Aqua test suite - now with zero type piracy! + # All operations dispatch on our wrapper types (SymNumber, SymVector, SymMatrix) + # instead of on ExaModels.AbstractNode, completely eliminating type piracy. + Aqua.test_all(CTParser) end end diff --git a/test/test_exa_linalg.jl b/test/test_exa_linalg.jl new file mode 100644 index 0000000..4b5e665 --- /dev/null +++ b/test/test_exa_linalg.jl @@ -0,0 +1,349 @@ +# test_exa_linalg + +function test_exa_linalg() + + # Setup: Create symbolic variables for testing + c = ExaModels.ExaCore() + X = ExaModels.variable(c, 5, 4) + + # Wrap with new SymVector/SymMatrix types + x = CTParser.SymVector([X[i, 1] for i in 1:5]) + y = CTParser.SymVector([X[i, 2] for i in 1:5]) + M = CTParser.SymMatrix([X[i, j] for i in 1:3, j in 1:3]) + N = CTParser.SymMatrix([X[i, j] for i in 1:2, j in 1:4]) + + A = randn(5, 5) + B = randn(3, 3) + C = randn(2, 2) + v = randn(5) + w = randn(3) + + @testset "Basic SymNumber properties" begin + println(" Testing basic SymNumber properties") + + # x[1] now returns SymNumber + @test x[1] isa CTParser.SymNumber + + # zero and one return SymNumber + @test zero(CTParser.SymNumber) isa CTParser.SymNumber + @test zero(x[1]) isa CTParser.SymNumber + @test one(CTParser.SymNumber) isa CTParser.SymNumber + @test one(x[1]) isa CTParser.SymNumber + + # Unwrapped values are Null + @test CTParser.unwrap_scalar(zero(CTParser.SymNumber)) isa ExaModels.Null + @test CTParser.unwrap_scalar(one(CTParser.SymNumber)) isa ExaModels.Null + + # Scalar properties + @test length(x[1]) == 1 + @test size(x[1]) == () + @test ndims(x[1]) == 0 + @test ndims(typeof(x[1])) == 0 + + # adjoint/transpose/conj for scalars + @test adjoint(x[1]) === x[1] + @test transpose(x[1]) === x[1] + @test conj(x[1]) === x[1] + + # broadcastable + @test Base.broadcastable(x[1]) isa Base.RefValue + + # iterate should return nothing (not iterable) + @test iterate(x[1]) === nothing + end + + @testset "Type promotion and conversion" begin + println(" Testing type promotion and conversion") + + @test convert(CTParser.SymNumber, 5.0) isa CTParser.SymNumber + @test convert(CTParser.SymNumber, x[1]) === x[1] + + # Promotion with numbers + @test promote_type(typeof(x[1]), Float64) == CTParser.SymNumber + @test promote_type(typeof(x[1]), Int) == CTParser.SymNumber + end + + @testset "Symbolic arithmetic helpers" begin + println(" Testing sym_add and sym_mul") + + # sym_add with Null(nothing) - wrapping in SymNumber + null_zero = CTParser.SymNumber(ExaModels.Null(nothing)) + @test CTParser.sym_add(null_zero, x[1]) === x[1] + @test CTParser.sym_add(x[1], null_zero) === x[1] + @test CTParser.sym_add(x[1], y[1]) isa CTParser.SymNumber + + # sym_mul + @test CTParser.sym_mul(x[1], y[1]) isa CTParser.SymNumber + @test CTParser.sym_mul(2.0, x[1]) isa CTParser.SymNumber + end + + @testset "Matrix-vector products" begin + println(" Testing matrix-vector products") + + # Numeric matrix × Symbolic vector + @test A * x isa CTParser.SymVector + @test length(A * x) == size(A, 1) + + # Symbolic matrix × Numeric vector + @test M * v[1:3] isa CTParser.SymVector + @test length(M * v[1:3]) == size(M, 1) + + # Symbolic matrix × Symbolic vector + @test M * x[1:3] isa CTParser.SymVector + @test length(M * x[1:3]) == size(M, 1) + + # Different sizes + @test B * x[1:3] isa CTParser.SymVector + @test length(B * x[1:3]) == 3 + end + + @testset "Row vector × Matrix (via adjoint)" begin + println(" Testing row vector × matrix") + + # Symbolic row × Numeric matrix + result = x' * A + @test result isa LinearAlgebra.Adjoint + @test size(parent(result)) == (size(A, 2),) + + # Numeric row × Symbolic matrix + result = v[1:3]' * M + @test result isa LinearAlgebra.Adjoint + @test size(parent(result)) == (size(M, 2),) + + # Symbolic row × Symbolic matrix + result = x[1:3]' * M + @test result isa LinearAlgebra.Adjoint + @test size(parent(result)) == (size(M, 2),) + end + + @testset "Matrix × Matrix products" begin + println(" Testing matrix × matrix products") + + # Numeric × Symbolic + @test B * M isa CTParser.SymMatrix + @test size(B * M) == (size(B, 1), size(M, 2)) + + # Symbolic × Numeric + @test M * B isa CTParser.SymMatrix + @test size(M * B) == (size(M, 1), size(B, 2)) + + # Symbolic × Symbolic + @test M * M isa CTParser.SymMatrix + @test size(M * M) == (size(M, 1), size(M, 2)) + + # Different dimensions + M2x3 = CTParser.SymMatrix([X[i, j] for i in 1:2, j in 1:3]) + M3x4 = CTParser.SymMatrix([X[i, j] for i in 1:3, j in 1:4]) + @test M2x3 * M3x4 isa CTParser.SymMatrix + @test size(M2x3 * M3x4) == (2, 4) + end + + @testset "Dot products" begin + println(" Testing dot products") + + # Symbolic · Symbolic - now returns SymNumber + @test dot(x, y) isa CTParser.SymNumber + @test dot(x[1:3], y[1:3]) isa CTParser.SymNumber + + # Numeric · Symbolic + @test dot(v, x) isa CTParser.SymNumber + @test dot(v[1:3], x[1:3]) isa CTParser.SymNumber + + # Symbolic · Numeric + @test dot(x, v) isa CTParser.SymNumber + @test dot(x[1:3], v[1:3]) isa CTParser.SymNumber + end + + @testset "Inner products via adjoint" begin + println(" Testing inner products (x' * y)") + + # Symbolic' × Symbolic - now returns SymNumber + @test x' * y isa CTParser.SymNumber + @test x' * x isa CTParser.SymNumber + + # Symbolic' × Numeric + @test x' * v isa CTParser.SymNumber + @test x[1:3]' * v[1:3] isa CTParser.SymNumber + + # Numeric' × Symbolic + @test v' * x isa CTParser.SymNumber + @test v[1:3]' * x[1:3] isa CTParser.SymNumber + end + + @testset "Quadratic forms" begin + println(" Testing quadratic forms") + + # x' * A * x - now returns SymNumber + @test x' * A * x isa CTParser.SymNumber + @test x[1:3]' * M * x[1:3] isa CTParser.SymNumber + @test x[1:3]' * B * x[1:3] isa CTParser.SymNumber + end + + @testset "Matrix transpose and adjoint" begin + println(" Testing matrix transpose and adjoint") + + # Transpose + @test M' isa CTParser.SymMatrix + @test size(M') == (size(M, 2), size(M, 1)) + @test transpose(M) isa CTParser.SymMatrix + @test size(transpose(M)) == (size(M, 2), size(M, 1)) + + # Non-square matrix + @test N' isa CTParser.SymMatrix + @test size(N') == (size(N, 2), size(N, 1)) + + # Transpose should work in products + @test M' * x[1:3] isa CTParser.SymVector + @test length(M' * x[1:3]) == size(M, 2) + end + + @testset "Vector norms" begin + println(" Testing vector norms") + + # Default norm (L2) - now returns SymNumber + @test norm(x) isa CTParser.SymNumber + @test norm(y) isa CTParser.SymNumber + + # norm_sqr + @test norm_sqr(x) isa CTParser.SymNumber + @test norm_sqr(x[1:3]) isa CTParser.SymNumber + + # Explicit L2 norm + @test norm(x, 2) isa CTParser.SymNumber + + # L1 norm + @test norm(x, 1) isa CTParser.SymNumber + @test norm(x[1:3], 1) isa CTParser.SymNumber + + # Lp norms + @test norm(x, 3) isa CTParser.SymNumber + @test norm(x, 4) isa CTParser.SymNumber + end + + @testset "Matrix norms" begin + println(" Testing matrix norms") + + # Frobenius norm (default) - now returns SymNumber + @test norm(M) isa CTParser.SymNumber + @test norm(N) isa CTParser.SymNumber + end + + @testset "Broadcasting - unary operations" begin + println(" Testing broadcasting with unary operations") + + # Unary functions on vectors + @test sin.(x) isa CTParser.SymVector + @test length(sin.(x)) == length(x) + @test cos.(x) isa CTParser.SymVector + @test exp.(x) isa CTParser.SymVector + @test log.(x) isa CTParser.SymVector + @test sqrt.(x) isa CTParser.SymVector + @test abs.(x) isa CTParser.SymVector + + # Unary functions on matrices + @test sin.(M) isa CTParser.SymMatrix + @test size(sin.(M)) == size(M) + @test cos.(M) isa CTParser.SymMatrix + @test exp.(M) isa CTParser.SymMatrix + @test log.(M) isa CTParser.SymMatrix + end + + @testset "Broadcasting - binary operations" begin + println(" Testing broadcasting with binary operations") + + # Element-wise arithmetic on vectors + @test x .+ y isa CTParser.SymVector + @test length(x .+ y) == length(x) + @test x .- y isa CTParser.SymVector + @test x .* y isa CTParser.SymVector + @test x ./ y isa CTParser.SymVector + @test x .^ 2 isa CTParser.SymVector + @test x .^ y isa CTParser.SymVector + + # Element-wise with scalars + @test 2.0 .* x isa CTParser.SymVector + @test x .+ 1.0 isa CTParser.SymVector + @test x .- 3.0 isa CTParser.SymVector + @test x ./ 2.0 isa CTParser.SymVector + + # Element-wise on matrices + @test M .+ M isa CTParser.SymMatrix + @test size(M .+ M) == size(M) + @test M .- M isa CTParser.SymMatrix + @test M .* M isa CTParser.SymMatrix + @test 2.0 .* M isa CTParser.SymMatrix + @test M ./ 2.0 isa CTParser.SymMatrix + end + + @testset "Broadcasting - compound expressions" begin + println(" Testing broadcasting with compound expressions") + + # Compound vector expressions + @test exp.(x) .+ 1 isa CTParser.SymVector + @test sin.(x) .* cos.(y) isa CTParser.SymVector + @test (x .+ y) ./ 2 isa CTParser.SymVector + @test sqrt.(x.^2 .+ y.^2) isa CTParser.SymVector + + # Compound matrix expressions + @test sin.(M) .+ cos.(M) isa CTParser.SymMatrix + @test (M .+ M) ./ 2 isa CTParser.SymMatrix + @test exp.(M) .* 2.0 isa CTParser.SymMatrix + end + + @testset "abs and abs2" begin + println(" Testing abs and abs2") + + # abs2 for scalar node - now returns SymNumber + @test abs2(x[1]) isa CTParser.SymNumber + + # abs2 via broadcasting + @test abs2.(x) isa CTParser.SymVector + @test length(abs2.(x)) == length(x) + end + + @testset "Edge cases and dimension mismatches" begin + println(" Testing edge cases") + + # Test that dimension mismatches are caught + A_wrong = randn(3, 4) + @test_throws AssertionError A_wrong * x # 4 != 5 + + @test_throws AssertionError M * v # 3 != 5 + + # Different length dot products + @test_throws AssertionError dot(x[1:3], y[1:4]) + + # Matrix dimension mismatch + M2x3 = CTParser.SymMatrix([X[i, j] for i in 1:2, j in 1:3]) + M4x2 = CTParser.SymMatrix([X[i, j] for i in 1:4, j in 1:2]) + @test_throws AssertionError M2x3 * M4x2 # 3 != 4 + end + + @testset "Mixed operations" begin + println(" Testing mixed operations") + + # Combine different operations + result = A * x + v + @test result isa CTParser.SymVector + @test length(result) == length(v) + + # Matrix chain + result = M * M * M + @test result isa CTParser.SymMatrix + @test size(result) == size(M) + + # Complex expression - now returns SymNumber + result = (x' * A * x) + dot(x, v) + norm(x)^2 + @test result isa CTParser.SymNumber + + # Broadcasting mixed with products + result = (A * x) .+ v + @test result isa CTParser.SymVector + + # Norm of matrix-vector product - now returns SymNumber + result = norm(M * x[1:3]) + @test result isa CTParser.SymNumber + end + + println(" All exa_linalg tests passed!") +end diff --git a/test/test_onepass_exa.jl b/test/test_onepass_exa.jl index 19c9e4f..086e9d2 100644 --- a/test/test_onepass_exa.jl +++ b/test/test_onepass_exa.jl @@ -3,7 +3,8 @@ activate_backend(:exa) # nota bene: needs to be executed before @def are expanded -# mock up of CTDirect.discretise for tests +# Mock up of CTDirect.discretise for tests + function discretise_exa( ocp; scheme=CTParser.__default_scheme_exa(), @@ -32,18 +33,14 @@ function discretise_exa_full( ) end +# Tests + function test_onepass_exa() - __test_onepass_exa(; scheme=:euler) - __test_onepass_exa(; scheme=:euler_implicit) - __test_onepass_exa(; scheme=:midpoint) - __test_onepass_exa(; scheme=:trapeze) - if CUDA.functional() - __test_onepass_exa(CUDABackend(); scheme=:euler) - __test_onepass_exa(CUDABackend(); scheme=:euler_implicit) - __test_onepass_exa(CUDABackend(); scheme=:midpoint) - __test_onepass_exa(CUDABackend(); scheme=:trapeze) - else - println("********** CUDA not available") + #l_scheme = [:euler, :euler_implicit, :midpoint, :trapeze] + l_scheme = [:midpoint] # debug + for scheme ∈ l_scheme + __test_onepass_exa(; scheme=scheme) + CUDA.functional() && __test_onepass_exa(CUDABackend(); scheme=scheme) end end @@ -52,6 +49,7 @@ function __test_onepass_exa( ) backend_name = isnothing(backend) ? "CPU" : "GPU" + @ignore begin # debug test_name = "min ($backend_name, $scheme)" @testset "$test_name" begin println(test_name) @@ -98,10 +96,542 @@ function __test_onepass_exa( @test CTParser.is_range(1:2) == true @test CTParser.is_range(1:2:5) == true @test CTParser.is_range(:(x:y:z)) == true - @test CTParser.as_range(1) == [1] + @test CTParser.as_range(1) == :((1):(1)) @test CTParser.as_range(1:2) == 1:2 - @test CTParser.as_range(:x) == [:x] - @test CTParser.as_range(:(x + 1)) == [:(x + 1)] + @test CTParser.as_range(:x) == :(x:x) + @test CTParser.as_range(:(x + 1)) == :((x + 1):(x + 1)) + end + + test_name = "bare symbols and ranges - costs ($backend_name, $scheme)" + @testset "$test_name" begin + println(test_name) + + # Test: Lagrange with sum over all state components + o = @def begin + t ∈ [0, 1], time + x ∈ R³, state + u ∈ R², control + ∂(x₁)(t) == u₁(t) + ∂(x₂)(t) == u₂(t) + ∂(x₃)(t) == u₁(t) + u₂(t) + x(0) == [1, 2, 3] + x(1) == [4, 5, 6] + ∫(sum(x(t))^2) → min + end + m = discretise_exa(o; backend=backend, scheme=scheme) + @test m isa ExaModels.ExaModel + + # Test: Lagrange with sum over range of states + o = @def begin + t ∈ [0, 1], time + x ∈ R³, state + u ∈ R², control + ∂(x₁)(t) == u₁(t) + ∂(x₂)(t) == u₂(t) + ∂(x₃)(t) == u₁(t) + u₂(t) + ∫(sum(x[1:2](t))^2) → min + end + m = discretise_exa(o; backend=backend, scheme=scheme) + @test m isa ExaModels.ExaModel + + # Test: Lagrange with sum over all controls + o = @def begin + t ∈ [0, 1], time + x ∈ R², state + u ∈ R³, control + ∂(x₁)(t) == u₁(t) + ∂(x₂)(t) == u₂(t) + ∫(sum(u(t))^2) → min + end + m = discretise_exa(o; backend=backend, scheme=scheme) + @test m isa ExaModels.ExaModel + + # Test: Lagrange with sum over range of controls + o = @def begin + t ∈ [0, 1], time + x ∈ R², state + u ∈ R³, control + ∂(x₁)(t) == u₁(t) + ∂(x₂)(t) == u₂(t) + ∫(sum(u[1:2](t))^2) → min + end + m = discretise_exa(o; backend=backend, scheme=scheme) + @test m isa ExaModels.ExaModel + + # Test: Mayer with sum over all states at t0 + o = @def begin + t ∈ [0, 1], time + x ∈ R³, state + u ∈ R², control + ∂(x₁)(t) == u₁(t) + ∂(x₂)(t) == u₂(t) + ∂(x₃)(t) == u₁(t) + u₂(t) + sum(x(0))^2 → min + end + m = discretise_exa(o; backend=backend, scheme=scheme) + @test m isa ExaModels.ExaModel + + # Test: Mayer with sum over all states at tf + o = @def begin + t ∈ [0, 1], time + x ∈ R³, state + u ∈ R², control + ∂(x₁)(t) == u₁(t) + ∂(x₂)(t) == u₂(t) + ∂(x₃)(t) == u₁(t) + u₂(t) + sum(x(1))^2 → min + end + m = discretise_exa(o; backend=backend, scheme=scheme) + @test m isa ExaModels.ExaModel + + # Test: Mayer with sum over range at t0 + o = @def begin + t ∈ [0, 1], time + x ∈ R³, state + u ∈ R², control + ∂(x₁)(t) == u₁(t) + ∂(x₂)(t) == u₂(t) + ∂(x₃)(t) == u₁(t) + u₂(t) + sum(x[1:2](0))^2 → min + end + m = discretise_exa(o; backend=backend, scheme=scheme) + @test m isa ExaModels.ExaModel + + # Test: Mayer with sum over range at tf + o = @def begin + t ∈ [0, 1], time + x ∈ R³, state + u ∈ R², control + ∂(x₁)(t) == u₁(t) + ∂(x₂)(t) == u₂(t) + ∂(x₃)(t) == u₁(t) + u₂(t) + sum(x[2:3](1))^2 → min + end + m = discretise_exa(o; backend=backend, scheme=scheme) + @test m isa ExaModels.ExaModel + + # Test: Bolza cost with bare symbols + o = @def begin + t ∈ [0, 1], time + x ∈ R³, state + u ∈ R², control + ∂(x₁)(t) == u₁(t) + ∂(x₂)(t) == u₂(t) + ∂(x₃)(t) == u₁(t) + u₂(t) + (sum(x(0))^2 + sum(x(1))^2) + ∫(sum(u(t))^2) → min + end + m = discretise_exa(o; backend=backend, scheme=scheme) + @test m isa ExaModels.ExaModel + + # Test: Bolza cost with ranges + o = @def begin + t ∈ [0, 1], time + x ∈ R³, state + u ∈ R², control + ∂(x₁)(t) == u₁(t) + ∂(x₂)(t) == u₂(t) + ∂(x₃)(t) == u₁(t) + u₂(t) + (sum(x[1:2](0)) + sum(x[2:3](1))) + ∫(sum(u[1:2](t))) → min + end + m = discretise_exa(o; backend=backend, scheme=scheme) + @test m isa ExaModels.ExaModel + end + + test_name = "bare symbols and ranges - constraints ($backend_name, $scheme)" + @testset "$test_name" begin + println(test_name) + + # Test: Initial constraint with bare symbol + o = @def begin + t ∈ [0, 1], time + x ∈ R³, state + u ∈ R², control + ∂(x₁)(t) == u₁(t) + ∂(x₂)(t) == u₂(t) + ∂(x₃)(t) == u₁(t) + u₂(t) + sum(x(0)) == 6 + ∫(u₁(t)^2 + u₂(t)^2) → min + end + m = discretise_exa(o; backend=backend, scheme=scheme) + @test m isa ExaModels.ExaModel + + # Test: Initial constraint with range + o = @def begin + t ∈ [0, 1], time + x ∈ R³, state + u ∈ R², control + ∂(x₁)(t) == u₁(t) + ∂(x₂)(t) == u₂(t) + ∂(x₃)(t) == u₁(t) + u₂(t) + sum(x[1:2](0)) == 3 + ∫(u₁(t)^2 + u₂(t)^2) → min + end + m = discretise_exa(o; backend=backend, scheme=scheme) + @test m isa ExaModels.ExaModel + + # Test: Final constraint with bare symbol + o = @def begin + t ∈ [0, 1], time + x ∈ R³, state + u ∈ R², control + ∂(x₁)(t) == u₁(t) + ∂(x₂)(t) == u₂(t) + ∂(x₃)(t) == u₁(t) + u₂(t) + sum(x(1)) == 15 + ∫(u₁(t)^2 + u₂(t)^2) → min + end + m = discretise_exa(o; backend=backend, scheme=scheme) + @test m isa ExaModels.ExaModel + + # Test: Final constraint with range + o = @def begin + t ∈ [0, 1], time + x ∈ R³, state + u ∈ R², control + ∂(x₁)(t) == u₁(t) + ∂(x₂)(t) == u₂(t) + ∂(x₃)(t) == u₁(t) + u₂(t) + sum(x[2:3](1)) == 11 + ∫(u₁(t)^2 + u₂(t)^2) → min + end + m = discretise_exa(o; backend=backend, scheme=scheme) + @test m isa ExaModels.ExaModel + + # Test: Boundary constraint combining t0 and tf with bare symbols + o = @def begin + t ∈ [0, 1], time + x ∈ R³, state + u ∈ R², control + ∂(x₁)(t) == u₁(t) + ∂(x₂)(t) == u₂(t) + ∂(x₃)(t) == u₁(t) + u₂(t) + sum(x(0)) + sum(x(1)) == 21 + ∫(u₁(t)^2 + u₂(t)^2) → min + end + m = discretise_exa(o; backend=backend, scheme=scheme) + @test m isa ExaModels.ExaModel + + # Test: Boundary constraint with ranges + o = @def begin + t ∈ [0, 1], time + x ∈ R³, state + u ∈ R², control + ∂(x₁)(t) == u₁(t) + ∂(x₂)(t) == u₂(t) + ∂(x₃)(t) == u₁(t) + u₂(t) + sum(x[1:2](0)) - sum(x[2:3](1)) == -8 + ∫(u₁(t)^2 + u₂(t)^2) → min + end + m = discretise_exa(o; backend=backend, scheme=scheme) + @test m isa ExaModels.ExaModel + + # Test: Path constraint with bare state symbol + o = @def begin + t ∈ [0, 1], time + x ∈ R³, state + u ∈ R², control + ∂(x₁)(t) == u₁(t) + ∂(x₂)(t) == u₂(t) + ∂(x₃)(t) == u₁(t) + u₂(t) + sum(x(t))^2 ≤ 100 + ∫(u₁(t)^2 + u₂(t)^2) → min + end + m = discretise_exa(o; backend=backend, scheme=scheme) + @test m isa ExaModels.ExaModel + + # Test: Path constraint with state range + o = @def begin + t ∈ [0, 1], time + x ∈ R³, state + u ∈ R², control + ∂(x₁)(t) == u₁(t) + ∂(x₂)(t) == u₂(t) + ∂(x₃)(t) == u₁(t) + u₂(t) + sum(x[1:2](t)) ≤ 10 + ∫(u₁(t)^2 + u₂(t)^2) → min + end + m = discretise_exa(o; backend=backend, scheme=scheme) + @test m isa ExaModels.ExaModel + + # Test: Path constraint with bare control symbol + o = @def begin + t ∈ [0, 1], time + x ∈ R², state + u ∈ R³, control + ∂(x₁)(t) == u₁(t) + ∂(x₂)(t) == u₂(t) + sum(u(t))^2 ≤ 5 + ∫(x₁(t)^2 + x₂(t)^2) → min + end + m = discretise_exa(o; backend=backend, scheme=scheme) + @test m isa ExaModels.ExaModel + + # Test: Path constraint with control range + o = @def begin + t ∈ [0, 1], time + x ∈ R², state + u ∈ R³, control + ∂(x₁)(t) == u₁(t) + ∂(x₂)(t) == u₂(t) + sum(u[1:2](t)) ≤ 3 + ∫(x₁(t)^2 + x₂(t)^2) → min + end + m = discretise_exa(o; backend=backend, scheme=scheme) + @test m isa ExaModels.ExaModel + + # Test: Mixed constraint with bare symbols + o = @def begin + t ∈ [0, 1], time + x ∈ R³, state + u ∈ R², control + ∂(x₁)(t) == u₁(t) + ∂(x₂)(t) == u₂(t) + ∂(x₃)(t) == u₁(t) + u₂(t) + sum(x(t)) + sum(u(t)) ≤ 15 + ∫(u₁(t)^2 + u₂(t)^2) → min + end + m = discretise_exa(o; backend=backend, scheme=scheme) + @test m isa ExaModels.ExaModel + + # Test: Mixed constraint with ranges + o = @def begin + t ∈ [0, 1], time + x ∈ R³, state + u ∈ R³, control + ∂(x₁)(t) == u₁(t) + ∂(x₂)(t) == u₂(t) + ∂(x₃)(t) == u₃(t) + sum(x[1:2](t)) + sum(u[2:3](t)) ≤ 8 + ∫(u₁(t)^2 + u₂(t)^2) → min + end + m = discretise_exa(o; backend=backend, scheme=scheme) + @test m isa ExaModels.ExaModel + end + + test_name = "bare symbols and ranges - dynamics ($backend_name, $scheme)" + @testset "$test_name" begin + println(test_name) + + # Test: Dynamics with sum over all states + o = @def begin + t ∈ [0, 1], time + x ∈ R³, state + u ∈ R², control + ∂(x₁)(t) == sum(x(t)) + ∂(x₂)(t) == u₁(t) + ∂(x₃)(t) == u₂(t) + ∫(u₁(t)^2 + u₂(t)^2) → min + end + m = discretise_exa(o; backend=backend, scheme=scheme) + @test m isa ExaModels.ExaModel + + # Test: Dynamics with sum over state range + o = @def begin + t ∈ [0, 1], time + x ∈ R³, state + u ∈ R², control + ∂(x₁)(t) == sum(x[2:3](t)) + ∂(x₂)(t) == u₁(t) + ∂(x₃)(t) == u₂(t) + ∫(u₁(t)^2 + u₂(t)^2) → min + end + m = discretise_exa(o; backend=backend, scheme=scheme) + @test m isa ExaModels.ExaModel + + # Test: Dynamics with sum over all controls + o = @def begin + t ∈ [0, 1], time + x ∈ R², state + u ∈ R³, control + ∂(x₁)(t) == sum(u(t)) + ∂(x₂)(t) == u₁(t) + ∫(u₁(t)^2 + u₂(t)^2) → min + end + m = discretise_exa(o; backend=backend, scheme=scheme) + @test m isa ExaModels.ExaModel + + # Test: Dynamics with sum over control range + o = @def begin + t ∈ [0, 1], time + x ∈ R², state + u ∈ R³, control + ∂(x₁)(t) == sum(u[1:2](t)) + ∂(x₂)(t) == u₃(t) + ∫(u₁(t)^2 + u₂(t)^2 + u₃(t)^2) → min + end + m = discretise_exa(o; backend=backend, scheme=scheme) + @test m isa ExaModels.ExaModel + + # Test: Dynamics with mixed bare symbols + o = @def begin + t ∈ [0, 1], time + x ∈ R³, state + u ∈ R², control + ∂(x₁)(t) == sum(x(t)) + sum(u(t)) + ∂(x₂)(t) == u₁(t) + ∂(x₃)(t) == u₂(t) + ∫(u₁(t)^2 + u₂(t)^2) → min + end + m = discretise_exa(o; backend=backend, scheme=scheme) + @test m isa ExaModels.ExaModel + + # Test: Dynamics with mixed ranges + o = @def begin + t ∈ [0, 1], time + x ∈ R³, state + u ∈ R³, control + ∂(x₁)(t) == sum(x[1:2](t)) + sum(u[2:3](t)) + ∂(x₂)(t) == u₁(t) + ∂(x₃)(t) == u₂(t) + ∫(u₁(t)^2 + u₂(t)^2 + u₃(t)^2) → min + end + m = discretise_exa(o; backend=backend, scheme=scheme) + @test m isa ExaModels.ExaModel + end + + test_name = "user-defined functions with ranges ($backend_name, $scheme)" + @testset "$test_name" begin + println(test_name) + + # Define user functions outside @def + f(x, u) = x[1] * x[3] + u[1]^2 * cos(u[2]) + g(x) = x[1] + 2 * x[2] + h(u) = u[1]^2 + sin(u[2]) + + # Test: User-defined function in Lagrange cost + o = @def begin + t ∈ [0, 1], time + x ∈ R³, state + u ∈ R², control + ∂(x₁)(t) == u₁(t) + ∂(x₂)(t) == u₂(t) + ∂(x₃)(t) == u₁(t) + u₂(t) + ∫(f(x(t), u(t))^2) → min + end + m = discretise_exa(o; backend=backend, scheme=scheme) + @test m isa ExaModels.ExaModel + + # Test: User-defined function in Mayer cost at t0 + o = @def begin + t ∈ [0, 1], time + x ∈ R³, state + u ∈ R², control + ∂(x₁)(t) == u₁(t) + ∂(x₂)(t) == u₂(t) + ∂(x₃)(t) == u₁(t) + u₂(t) + f(x(0), [0, 0])^2 → min + end + m = discretise_exa(o; backend=backend, scheme=scheme) + @test m isa ExaModels.ExaModel + + # Test: User-defined function in Mayer cost at tf + o = @def begin + t ∈ [0, 1], time + x ∈ R³, state + u ∈ R², control + ∂(x₁)(t) == u₁(t) + ∂(x₂)(t) == u₂(t) + ∂(x₃)(t) == u₁(t) + u₂(t) + f(x(1), [0, 0])^2 → min + end + m = discretise_exa(o; backend=backend, scheme=scheme) + @test m isa ExaModels.ExaModel + + # Test: User-defined function in Bolza cost + o = @def begin + t ∈ [0, 1], time + x ∈ R³, state + u ∈ R², control + ∂(x₁)(t) == u₁(t) + ∂(x₂)(t) == u₂(t) + ∂(x₃)(t) == u₁(t) + u₂(t) + (f(x(0), [0, 0]) + f(x(1), [0, 0])) + ∫(h(u(t))) → min + end + m = discretise_exa(o; backend=backend, scheme=scheme) + @test m isa ExaModels.ExaModel + + # Test: User-defined function in initial constraint + o = @def begin + t ∈ [0, 1], time + x ∈ R³, state + u ∈ R², control + ∂(x₁)(t) == u₁(t) + ∂(x₂)(t) == u₂(t) + ∂(x₃)(t) == u₁(t) + u₂(t) + f(x(0), [0, 0]) == 5 + ∫(u₁(t)^2 + u₂(t)^2) → min + end + m = discretise_exa(o; backend=backend, scheme=scheme) + @test m isa ExaModels.ExaModel + + # Test: User-defined function in final constraint + o = @def begin + t ∈ [0, 1], time + x ∈ R³, state + u ∈ R², control + ∂(x₁)(t) == u₁(t) + ∂(x₂)(t) == u₂(t) + ∂(x₃)(t) == u₁(t) + u₂(t) + f(x(1), [0, 0]) == 10 + ∫(u₁(t)^2 + u₂(t)^2) → min + end + m = discretise_exa(o; backend=backend, scheme=scheme) + @test m isa ExaModels.ExaModel + + # Test: User-defined function in boundary constraint + o = @def begin + t ∈ [0, 1], time + x ∈ R³, state + u ∈ R², control + ∂(x₁)(t) == u₁(t) + ∂(x₂)(t) == u₂(t) + ∂(x₃)(t) == u₁(t) + u₂(t) + f(x(0), [0, 0]) + f(x(1), [0, 0]) == 15 + ∫(u₁(t)^2 + u₂(t)^2) → min + end + m = discretise_exa(o; backend=backend, scheme=scheme) + @test m isa ExaModels.ExaModel + + # Test: User-defined function in path constraint + o = @def begin + t ∈ [0, 1], time + x ∈ R³, state + u ∈ R², control + ∂(x₁)(t) == u₁(t) + ∂(x₂)(t) == u₂(t) + ∂(x₃)(t) == u₁(t) + u₂(t) + f(x(t), u(t)) ≤ 10 + ∫(u₁(t)^2 + u₂(t)^2) → min + end + m = discretise_exa(o; backend=backend, scheme=scheme) + @test m isa ExaModels.ExaModel + + # Test: User-defined function in dynamics + o = @def begin + t ∈ [0, 1], time + x ∈ R³, state + u ∈ R², control + ∂(x₁)(t) == u₁(t) + ∂(x₂)(t) == u₂(t) + ∂(x₃)(t) == g(x[1:2](t)) + ∫(u₁(t)^2 + u₂(t)^2) → min + end + m = discretise_exa(o; backend=backend, scheme=scheme) + @test m isa ExaModels.ExaModel + + # Test: Multiple user-defined functions + o = @def begin + t ∈ [0, 1], time + x ∈ R³, state + u ∈ R², control + ∂(x₁)(t) == g(x[1:2](t)) + ∂(x₂)(t) == u₁(t) + ∂(x₃)(t) == u₂(t) + h(u(t)) ≤ 5 + (f(x(0), [0, 0])) + ∫(f(x(t), u(t))) → min + end + m = discretise_exa(o; backend=backend, scheme=scheme) + @test m isa ExaModels.ExaModel end test_name = "pragma ($backend_name, $scheme)" @@ -424,10 +954,22 @@ function __test_onepass_exa( x ∈ R⁴, state u ∈ R⁵, control v ≤ [1, 2, 3] + v ≥ [1, 2, 3] + v[1] ≤ 1 + v[1] ≥ 1 + v[1:2] ≤ [1, 2] v[1:2] ≥ [1, 2] + x[2](t) ≤ 1 + x[2:2:4](t) ≤ [1, 2] + x[2:4](t) ≤ [1, 2, 3] + x[2](t) ≥ 1 + x[2:2:4](t) ≥ [1, 2] + x[2:4](t) ≥ [1, 2, 3] + u[2](t) ≤ 1 u[2:2:4](t) ≤ [1, 2] - u[2:4](t) ≥ [1, 2, 3] - u[2:2:4](t) ≤ [1, 2] + u[2:4](t) ≤ [1, 2, 3] + u[2](t) ≥ 1 + u[2:2:4](t) ≥ [1, 2] u[2:4](t) ≥ [1, 2, 3] ∂(x₁)(t) == x₁(t) ∂(x₂)(t) == x₁(t) @@ -1005,4 +1547,291 @@ function __test_onepass_exa( sol = madnlp(m; tol=tolerance, kwargs...) @test sol.status == MadNLP.SOLVE_SUCCEEDED end + + test_name = "use case no. 4: vectorised ($backend_name, $scheme)" + @testset "$test_name" begin + println(test_name) + + f₁(x, u) = 2x[1] * u[1] + x[2] * u[2] + f₂(x) = x[1] + 2x[2] - x[3] + f₃(x0, xf) = x0[2]^2 + sum(xf)^2 + f₄(u) = sum(u.^2) + + o1 = @def begin + t ∈ [0, 1], time + x ∈ R³, state + u ∈ R², control + + x[1:2:3](0) == [1, 3] + + ∂(x₁)(t) == sum(x(t)) + ∂(x₂)(t) == f₁(x(t), u(t)) + ∂(x₃)(t) == f₂(x(t)) + + f₃(x(0), x(1)) + 0.5∫( f₄(u(t)) ) → min + end + + N = 250 + max_iter = 10 + m1, _ = discretise_exa_full(o1; grid_size=N, backend=backend, scheme=scheme) + @test m1 isa ExaModels.ExaModel + sol1 = madnlp(m1; tol=tolerance, max_iter=max_iter, kwargs...) + obj1 = sol1.objective + + o2 = @def begin + t ∈ [0, 1], time + x ∈ R³, state + u ∈ R², control + + x[1:2:3](0) == [1, 3] + + ∂(x₁)(t) == x₁(t) + x₂(t) + x₃(t) + ∂(x₂)(t) == 2x₁(t) * u₁(t) + x₂(t) * u₂(t) + ∂(x₃)(t) == x₁(t) + 2x₂(t) - x₃(t) + + (x₂(0)^2 + (x₁(1) + x₂(1) + x₃(1))^2) + 0.5∫( u₁(t)^2 + u₂(t)^2 ) → min + end + + m2, _ = discretise_exa_full(o2; grid_size=N, backend=backend, scheme=scheme) + @test m2 isa ExaModels.ExaModel + sol2 = madnlp(m2; tol=tolerance, max_iter=max_iter, kwargs...) + obj2 = sol2.objective + + __atol = 1e-9 + @test obj1 - obj2 ≈ 0 atol = __atol + end + + test_name = "use case no. 5: vectorised with ranges ($backend_name, $scheme)" + @testset "$test_name" begin + println(test_name) + + g₁(x) = x[1]^2 + x[2]^2 + g₂(u) = u[1] * u[2] + + # Vectorised version using ranges + o1 = @def begin + t ∈ [0, 1], time + x ∈ R⁴, state + u ∈ R², control + + x(0) == [0, .1, .2, .3] + + ∂(x₁)(t) == g₁(x[1:2](t)) + ∂(x₂)(t) == g₂(u(t)) + ∂(x₃)(t) == sum(x[2:4](t)) + ∂(x₄)(t) == u₁(t) + + sum(x[1:3](1))^2 + 0.5∫( sum(u(t).^2) ) → min + end + + N = 250 + max_iter = 10 + m1, _ = discretise_exa_full(o1; grid_size=N, backend=backend, scheme=scheme) + @test m1 isa ExaModels.ExaModel + sol1 = madnlp(m1; tol=tolerance, max_iter=max_iter, kwargs...) + obj1 = sol1.objective + + # Non-vectorised version using subscripts + o2 = @def begin + t ∈ [0, 1], time + x ∈ R⁴, state + u ∈ R², control + + x(0) == [0, .1, .2, .3] + + ∂(x₁)(t) == x₁(t)^2 + x₂(t)^2 + ∂(x₂)(t) == u₁(t) * u₂(t) + ∂(x₃)(t) == x₂(t) + x₃(t) + x₄(t) + ∂(x₄)(t) == u₁(t) + + (x₁(1) + x₂(1) + x₃(1))^2 + 0.5∫( u₁(t)^2 + u₂(t)^2 ) → min + end + + m2, _ = discretise_exa_full(o1; grid_size=N, backend=backend, scheme=scheme) + @test m2 isa ExaModels.ExaModel + sol2 = madnlp(m2; tol=tolerance, max_iter=max_iter, kwargs...) + obj2 = sol2.objective + + __atol = 1e-9 + @test obj1 - obj2 ≈ 0 atol = __atol + end + + test_name = "use case no. 6: vectorised constraints ($backend_name, $scheme)" + @testset "$test_name" begin + println(test_name) + + h₁(x) = x[1] + 2x[2] + 3x[3] + h₂(u) = u[1]^2 + u[2]^2 + + # Vectorised version + o1 = @def begin + t ∈ [0, 1], time + x ∈ R³, state + u ∈ R², control + + sum(x(0).^2) == 1.5 + h₁(x(1)) ≤ 200 + + ∂(x₁)(t) == u₁(t) + ∂(x₂)(t) == u₂(t) + ∂(x₃)(t) == sum(u(t)) + + h₂(u(t)) ≤ 10 + + sum(x(1))^2 + ∫( h₂(u(t)) ) → min + end + + N = 250 + max_iter = 10 + m1, _ = discretise_exa_full(o1; grid_size=N, backend=backend, scheme=scheme) + @test m1 isa ExaModels.ExaModel + sol1 = madnlp(m1; tol=tolerance, max_iter=max_iter, kwargs...) + obj1 = sol1.objective + + # Non-vectorised version + o2 = @def begin + t ∈ [0, 1], time + x ∈ R³, state + u ∈ R², control + + x₁(0)^2 + x₂(0)^2 + x₃(0)^2 == 1.5 + x₁(1) + 2x₂(1) + 3x₃(1) ≤ 200 + + ∂(x₁)(t) == u₁(t) + ∂(x₂)(t) == u₂(t) + ∂(x₃)(t) == u₁(t) + u₂(t) + + u₁(t)^2 + u₂(t)^2 ≤ 10 + + (x₁(1) + x₂(1) + x₃(1))^2 + ∫( u₁(t)^2 + u₂(t)^2 ) → min + end + + m2, _ = discretise_exa_full(o2; grid_size=N, backend=backend, scheme=scheme) + @test m2 isa ExaModels.ExaModel + sol2 = madnlp(m2; tol=tolerance, max_iter=max_iter, kwargs...) + obj2 = sol2.objective + + __atol = 1e-9 + @test obj1 - obj2 ≈ 0 atol = __atol + end + + # todo: test below inactived on GPU because run is unstable + if isnothing(backend) test_name = "use case no. 7: mixed vectorisation ($backend_name, $scheme)" + @testset "$test_name" begin + println(test_name) + + # User-defined functions + p₁(x, u) = x[1] * u[1] + x[2] * u[2] + p₂(x) = x[1]^2 + x[2]^2 + x[3]^2 + + # Vectorised version with mixed patterns + o1 = @def begin + t ∈ [0, 1], time + x ∈ R³, state + u ∈ R², control + + x[1:2](0) == [0, 0.1] + -0.1 ≤ x₃(0) ≤ 0.1 + sum(x(1)) == 0.2 + + ∂(x₁)(t) == p₁(x[1:2](t), u(t)) + ∂(x₂)(t) == sum(u(t)) + ∂(x₃)(t) == x₁(t) + + p₂(x(t)) ≤ 50 + + (p₂(x(0)) + sum(x[1:2](1))^2) + 0.5∫( sum(u(t).^2) ) → min + end + + N = 250 + max_iter = 10 + m1, _ = discretise_exa_full(o1; grid_size=N, backend=backend, scheme=scheme) + @test m1 isa ExaModels.ExaModel + sol1 = madnlp(m1; tol=tolerance, max_iter=max_iter, kwargs...) + obj1 = sol1.objective + + # Non-vectorised version + o2 = @def begin + t ∈ [0, 1], time + x ∈ R³, state + u ∈ R², control + + x₁(0) == 0 + x₂(0) == 0.1 + -0.1 ≤ x₃(0) ≤ 0.1 + x₁(1) + x₂(1) + x₃(1) == 0.2 + + ∂(x₁)(t) == x₁(t) * u₁(t) + x₂(t) * u₂(t) + ∂(x₂)(t) == u₁(t) + u₂(t) + ∂(x₃)(t) == x₁(t) + + x₁(t)^2 + x₂(t)^2 + x₃(t)^2 ≤ 50 + + (x₁(0)^2 + x₂(0)^2 + x₃(0)^2 + (x₁(1) + x₂(1))^2) + 0.5∫( u₁(t)^2 + u₂(t)^2 ) → min + end + + m2, _ = discretise_exa_full(o2; grid_size=N, backend=backend, scheme=scheme) + @test m2 isa ExaModels.ExaModel + sol2 = madnlp(m2; tol=tolerance, max_iter=max_iter, kwargs...) + obj2 = sol2.objective + + __atol = 1e-9 + @test obj1 - obj2 ≈ 0 atol = __atol + end end + end # debug + + test_name = "use case no. 8: vectorised dynamics ($backend_name, $scheme)" + @testset "$test_name" begin + println(test_name) + + tf = 5 + x0 = [0, 1] + A = [0 1; -1 0] + B = [0, 1] + Q = [1 0; 0 1] + R = 1 + + # Vectorised version + o1 = @def begin + t ∈ [0, tf], time + x ∈ R², state + u ∈ R, control + x(0) == x0 + #∂(x₁)(t) == dot(A[1, :], x(t)) + u(t) * B[1] + ∂(x₁)(t) == x₂(t) + #∂(x₂)(t) == dot(A[2, :], x(t)) + u(t) * B[2] + ∂(x₂)(t) == -x₁(t) + u(t) + #0.5∫( x(t)' * Q * x(t) + u(t)' * R * u(t) ) → min + 0.5∫( x₁(t)^2 + x₂(t)^2 + u(t)^2 ) → min + end + + N = 250 + max_iter = 10 + m1, _ = discretise_exa_full(o1; grid_size=N, backend=backend, scheme=scheme) + @test m1 isa ExaModels.ExaModel + sol1 = madnlp(m1; tol=tolerance, max_iter=max_iter, kwargs...) + @test sol1.status == MadNLP.SOLVE_SUCCEEDED + obj1 = sol1.objective + + @ignore begin #debug + # Non-vectorised version + o2 = @def begin + t ∈ [0, tf], time + x ∈ R², state + u ∈ R, control + x(0) == x0 + ∂(x₁)(t) == x₂(t) + ∂(x₂)(t) == -x₁(t) + u(t) + 0.5∫( x₁(t)^2 + x₂(t)^2 + u(t)^2 ) → min + end + + m2, _ = discretise_exa_full(o2; grid_size=N, backend=backend, scheme=scheme) + @test m2 isa ExaModels.ExaModel + sol2 = madnlp(m2; tol=tolerance, max_iter=max_iter, kwargs...) + obj2 = sol2.objective + + __atol = 1e-9 + @test obj1 - obj2 ≈ 0 atol = __atol + end # debug + end end diff --git a/test/test_onepass_exa_bis.jl b/test/test_onepass_exa_bis.jl index ee9ff08..2e67b91 100644 --- a/test/test_onepass_exa_bis.jl +++ b/test/test_onepass_exa_bis.jl @@ -1,4 +1,5 @@ # test_onepass_exa_bis +# todo: merge with test_onepass_exa function test_onepass_exa_bis() @testset "p_dynamics_exa! errors" begin @@ -44,16 +45,16 @@ function test_onepass_exa_bis() @test CTParser.is_range(:(a:b)) @test CTParser.as_range(:(a:b:c)) == :(a:b:c) - # Fallback to single-element vector when not a range - @test CTParser.as_range(:foo) == [:foo] + # Fallback to single-element range expression when not a range + @test CTParser.as_range(:foo) == :(foo:foo) # Edge cases @test !CTParser.is_range(42) @test !CTParser.is_range("string") @test !CTParser.is_range(:(f(x))) @test CTParser.is_range(1:10) - @test CTParser.as_range(42) == [42] - @test CTParser.as_range(:(x + y)) == [:(x + y)] + @test CTParser.as_range(42) == :((42):(42)) + @test CTParser.as_range(:(x + y)) == :((x + y):(x + y)) end @testset "p_dynamics_coord_exa! preconditions" begin @@ -506,4 +507,138 @@ function test_onepass_exa_bis() ex = CTParser.p_constraint_exa!(p, p_ocp, 0, :(x[1](0) + v), 1, :variable_fun, :c1) @test ex isa Expr end + + # ============================================================================ + # P_CONSTRAINT_EXA! - :other constraint type error (invalid constraints) + # ============================================================================ + + @testset "p_constraint_exa! :other constraint type error" begin + println("p_constraint_exa! :other type (bis)") + + p = CTParser.ParsingInfo() + p.lnum = 1 + p.line = "constraint exa :other test" + p.t = :t + p.t0 = 0 + p.tf = 1 + p.x = :x + p.u = :u + p.v = :v + p.dt = :dt + p_ocp = :p_ocp + + # Constraint with :other type should raise an error + # This simulates what happens when constraint_type returns :other + ex = CTParser.p_constraint_exa!(p, p_ocp, 0, :(x[1](0) + u[1](t)), 1, :other, :c1) + @test ex isa Expr + @test_throws ParsingError eval(ex) + + # Another :other case - the exact example from the user + ex2 = CTParser.p_constraint_exa!( + p, p_ocp, nothing, :(x[1](0) * u[1](t) + u[2](t)^2), 1, :other, :c2 + ) + @test ex2 isa Expr + @test_throws ParsingError eval(ex2) + end + + @testset "p_constraint! detects :other constraint type (exa)" begin + println("p_constraint! detects :other for exa (bis)") + + p = CTParser.ParsingInfo() + p.lnum = 1 + p.line = "constraint detection test exa" + p.t = :t + p.t0 = 0 + p.tf = 1 + p.x = :x + p.u = :u + p.v = :v + p.dim_x = 2 + p.dim_u = 2 + p.dt = :dt + p_ocp = :p_ocp + + # Test that p_constraint! correctly identifies invalid constraints + # Mixed initial state and control: x1(0) * u1(t) + u2(t)^2 <= 1 + # This should result in constraint_type returning :other + ex = CTParser.p_constraint!( + p, p_ocp, nothing, :(x[1](0) * u[1](t) + u[2](t)^2), 1; backend=:exa + ) + @test ex isa Expr + @test_throws ParsingError eval(ex) + + # Mixed final state and control at initial time: x(tf) + u(t0) + # This should also result in :other + ex2 = CTParser.p_constraint!(p, p_ocp, 0, :(x[1](tf) + u[1](0)), 1; backend=:exa) + @test ex2 isa Expr + @test_throws ParsingError eval(ex2) + + # Control at initial and final time: u(t0) + u(tf) + ex3 = CTParser.p_constraint!(p, p_ocp, 0, :(u[1](0) + u[1](tf)), 1; backend=:exa) + @test ex3 isa Expr + @test_throws ParsingError eval(ex3) + end + + # ============================================================================ + # @def_exa MACRO - Invalid constraints that should raise ParsingError + # ============================================================================ + + @testset "@def_exa macro :other constraint type error" begin + println("@def_exa macro :other constraint (bis)") + + backend = nothing + + # Test 1: Mixed initial state and control - x1(0) * u1(t) + u2(t)^2 <= 1 + # This should trigger constraint_type to return :other and raise ParsingError + o = @def_exa begin + t ∈ [0, 1], time + x ∈ R², state + u ∈ R², control + x[1](0) * u[1](t) + u[2](t)^2 ≤ 1 + ẋ₁(t) == u[1](t) + ẋ₂(t) == u[2](t) + end + @test_throws ParsingError o(; backend=backend) + + # Test 2: Mixed final state and control at initial time - x(tf) + u(t0) + o = @def_exa begin + t ∈ [0, 1], time + x ∈ R, state + u ∈ R, control + x(1) + u(0) ≤ 1 + ẋ(t) == u(t) + end + @test_throws ParsingError o(; backend=backend) + + # Test 3: Control at both initial and final time - u(t0) + u(tf) + o = @def_exa begin + t ∈ [0, 1], time + x ∈ R, state + u ∈ R, control + u(0) + u(1) ≤ 1 + ẋ(t) == u(t) + end + @test_throws ParsingError o(; backend=backend) + + # Test 4: Another invalid mixing - state at t and control at t0 + o = @def_exa begin + t ∈ [0, 1], time + x ∈ R, state + u ∈ R, control + x(t) + u(0) ≤ 1 + ẋ(t) == u(t) + end + @test_throws ParsingError o(; backend=backend) + + # Test 5: The exact user example - x1(0) * u1(t) + u2(t)^2 <= 1 + o = @def_exa begin + t ∈ [0, 1], time + x ∈ R², state + u ∈ R², control + x₁(0) * u₁(t) + u₂(t)^2 ≤ 1 + ẋ₁(t) == u₁(t) + ẋ₂(t) == u₂(t) + end + @test_throws ParsingError o(; backend=backend) + end end diff --git a/test/test_onepass_fun.jl b/test/test_onepass_fun.jl index e7b37ba..ab08048 100644 --- a/test/test_onepass_fun.jl +++ b/test/test_onepass_fun.jl @@ -757,7 +757,7 @@ function test_onepass_fun() r = y₃ v = y₄ aa = y₁(__s) - ẏ(__s) == [aa(__s), r²(__s) + w(__s) + z₁, 0, 0] + ẏ(__s) == [aa(__s), (r^2)(__s) + w(__s) + z₁, 0, 0] 0 => min # generic (untested) end z = [5, 6] @@ -851,7 +851,7 @@ function test_onepass_fun() v = y₄ aa = y₁(__s) ∂(y[1])(__s) == aa(__s) - ∂(y[2])(__s) == r²(__s) + w(__s) + z₁ + ∂(y[2])(__s) == (r^2)(__s) + w(__s) + z₁ ∂(y[3])(__s) == 0 ∂(y[4])(__s) == 0 0 => min # generic (untested) @@ -1241,10 +1241,10 @@ function test_onepass_fun() x(0) ≤ 0, (1) x(1) ≤ 0 x(1) ≤ 0, (2) - x³(0) ≤ 0 - x³(0) ≤ 0, (3) - x³(1) ≤ 0 - x³(1) ≤ 0, (4) + (x^3)(0) ≤ 0 + (x^3)(0) ≤ 0, (3) + (x^3)(1) ≤ 0 + (x^3)(1) ≤ 0, (4) x(t) ≤ 0 x(t) ≤ 0, (5) x(t) ≤ 0 @@ -1253,10 +1253,10 @@ function test_onepass_fun() u₁(t) ≤ 0, (7) u₁(t) ≤ 0 u₁(t) ≤ 0, (8) - x³(t) ≤ 0 - x³(t) ≤ 0, (9) - x³(t) ≤ 0 - x³(t) ≤ 0, (10) + x(t) ≤ 0 + (x^3)(t) ≤ 0, (9) + (x^3)(t) ≤ 0 + (x^3)(t) ≤ 0, (10) (u₁^3)(t) ≤ 0 (u₁^3)(t) ≤ 0, (11) (u₁^3)(t) ≤ 0 @@ -1311,10 +1311,10 @@ function test_onepass_fun() x(0) ≥ 0, (1) x(1) ≥ 0 x(1) ≥ 0, (2) - x³(0) ≥ 0 - x³(0) ≥ 0, (3) - x³(1) ≥ 0 - x³(1) ≥ 0, (4) + (x^3)(0) ≥ 0 + (x^3)(0) ≥ 0, (3) + (x^3)(1) ≥ 0 + (x^3)(1) ≥ 0, (4) x(t) ≥ 0 x(t) ≥ 0, (5) x(t) ≥ 0 @@ -1323,10 +1323,10 @@ function test_onepass_fun() u₁(t) ≥ 0, (7) u₁(t) ≥ 0 u₁(t) ≥ 0, (8) - x³(t) ≥ 0 - x³(t) ≥ 0, (9) - x³(t) ≥ 0 - x³(t) ≥ 0, (10) + (x^3)(t) ≥ 0 + (x^3)(t) ≥ 0, (9) + (x^3)(t) ≥ 0 + (x^3)(t) ≥ 0, (10) (u₁^3)(t) ≥ 0 (u₁^3)(t) ≥ 0, (11) (u₁^3)(t) ≥ 0 diff --git a/test/test_onepass_fun_bis.jl b/test/test_onepass_fun_bis.jl index c4500de..d9a7e22 100644 --- a/test/test_onepass_fun_bis.jl +++ b/test/test_onepass_fun_bis.jl @@ -500,4 +500,129 @@ function test_onepass_fun_bis() @test ex2 isa Expr @test_throws ParsingError eval(ex2) end + + # ============================================================================ + # P_CONSTRAINT! - :other constraint type error (invalid constraints) + # ============================================================================ + + @testset "p_constraint_fun! :other constraint type error" begin + println("p_constraint_fun! :other type (bis)") + + p = CTParser.ParsingInfo() + p.lnum = 1 + p.line = "constraint :other test" + p.t = :t + p.t0 = 0 + p.tf = 1 + p.x = :x + p.u = :u + p.v = :v + p_ocp = :p_ocp + + # Constraint with :other type should raise an error + # This simulates what happens when constraint_type returns :other + ex = CTParser.p_constraint_fun!(p, p_ocp, 0, :(x[1](0) + u[1](t)), 1, :other, :c1) + @test ex isa Expr + @test_throws ParsingError eval(ex) + + # Another :other case + ex2 = CTParser.p_constraint_fun!( + p, p_ocp, nothing, :(x[1](0) * u[1](t) + u[2](t)^2), 1, :other, :c2 + ) + @test ex2 isa Expr + @test_throws ParsingError eval(ex2) + end + + @testset "p_constraint! detects :other constraint type" begin + println("p_constraint! detects :other (bis)") + + p = CTParser.ParsingInfo() + p.lnum = 1 + p.line = "constraint detection test" + p.t = :t + p.t0 = 0 + p.tf = 1 + p.x = :x + p.u = :u + p.v = :v + p.dim_x = 2 + p.dim_u = 2 + p_ocp = :p_ocp + + # Test that p_constraint! correctly identifies invalid constraints + # Mixed initial state and control: x1(0) * u1(t) + u2(t)^2 <= 1 + # This should result in constraint_type returning :other + ex = CTParser.p_constraint!( + p, p_ocp, nothing, :(x[1](0) * u[1](t) + u[2](t)^2), 1 + ) + @test ex isa Expr + @test_throws ParsingError eval(ex) + + # Mixed final state and control at initial time: x(tf) + u(t0) + # This should also result in :other + ex2 = CTParser.p_constraint!(p, p_ocp, 0, :(x[1](tf) + u[1](0)), 1) + @test ex2 isa Expr + @test_throws ParsingError eval(ex2) + + # Control at initial and final time: u(t0) + u(tf) + ex3 = CTParser.p_constraint!(p, p_ocp, 0, :(u[1](0) + u[1](tf)), 1) + @test ex3 isa Expr + @test_throws ParsingError eval(ex3) + end + + # ============================================================================ + # @def MACRO - Invalid constraints that should raise ParsingError + # ============================================================================ + + @testset "@def macro :other constraint type error" begin + println("@def macro :other constraint (bis)") + + # Test 1: Mixed initial state and control - x1(0) * u1(t) + u2(t)^2 <= 1 + # This should trigger constraint_type to return :other and raise ParsingError + @test_throws ParsingError @eval @def begin + t ∈ [0, 1], time + x ∈ R², state + u ∈ R², control + x[1](0) * u[1](t) + u[2](t)^2 ≤ 1 + x(0) == [0, 0] + x(1) == [1, 1] + ẋ(t) == [u[1](t), u[2](t)] + ∫(u[1](t)^2 + u[2](t)^2) → min + end + + # Test 2: Mixed final state and control at initial time - x(tf) + u(t0) + @test_throws ParsingError @eval @def begin + t ∈ [0, 1], time + x ∈ R, state + u ∈ R, control + x(1) + u(0) ≤ 1 + x(0) == 0 + ẋ(t) == u(t) + ∫(u(t)^2) → min + end + + # Test 3: Control at both initial and final time - u(t0) + u(tf) + @test_throws ParsingError @eval @def begin + t ∈ [0, 1], time + x ∈ R, state + u ∈ R, control + u(0) + u(1) ≤ 1 + x(0) == 0 + x(1) == 1 + ẋ(t) == u(t) + ∫(u(t)^2) → min + end + + # Test 4: Another invalid mixing - state at t and control at t0 + @test_throws ParsingError @eval @def begin + t ∈ [0, 1], time + x ∈ R, state + u ∈ R, control + x(t) + u(0) ≤ 1 + x(0) == 0 + x(1) == 1 + ẋ(t) == u(t) + ∫(u(t)^2) → min + end + end end diff --git a/test/test_usecase8_gpu.jl b/test/test_usecase8_gpu.jl new file mode 100644 index 0000000..248547a --- /dev/null +++ b/test/test_usecase8_gpu.jl @@ -0,0 +1,87 @@ +# Temporary test file for use case no. 8 with GPU backend only +# Testing inlined expressions (not using dot) to verify GPU error is in MadNLP + +activate_backend(:exa) + +# Mock up of CTDirect.discretise for tests +function discretise_exa_full( + ocp; + scheme=CTParser.__default_scheme_exa(), + grid_size=CTParser.__default_grid_size_exa(), + backend=CTParser.__default_backend_exa(), + init=CTParser.__default_init_exa(), + base_type=CTParser.__default_base_type_exa(), +) + build_exa = CTModels.get_build_examodel(ocp) + return build_exa(; + scheme=scheme, grid_size=grid_size, backend=backend, init=init, base_type=base_type + ) +end + +function test_usecase8_gpu() + if !CUDA.functional() + @test_skip "CUDA not functional, skipping GPU test" + return + end + + backend = CUDABackend() + backend_name = "GPU" + scheme = :midpoint + tolerance = 1e-8 + kwargs = () + + test_name = "use case no. 8: vectorised dynamics ($backend_name, $scheme) - INLINED EXPRESSIONS" + @testset "$test_name" begin + println(test_name) + + tf = 5 + x0 = [0, 1] + A = [0 1; -1 0] + B = [0, 1] + Q = [1 0; 0 1] + R = 1 + + # Vectorised version with INLINED expressions (not using dot) + # Original dot expressions are commented out to show the equivalence + o1 = @def begin + t ∈ [0, tf], time + x ∈ R², state + u ∈ R, control + x(0) == x0 + #∂(x₁)(t) == dot(A[1, :], x(t)) + u(t) * B[1] # = 0*x₁(t) + 1*x₂(t) = x₂(t) + ∂(x₁)(t) == x₂(t) + #∂(x₂)(t) == dot(A[2, :], x(t)) + u(t) * B[2] # = -1*x₁(t) + 0*x₂(t) + u(t)*1 + ∂(x₂)(t) == -x₁(t) + u(t) + #0.5∫( x(t)' * Q * x(t) + u(t)' * R * u(t) ) → min + 0.5∫( x₁(t)^2 + x₂(t)^2 + u(t)^2 ) → min + end + + N = 250 + max_iter = 10 + m1, _ = discretise_exa_full(o1; grid_size=N, backend=backend, scheme=scheme) + @test m1 isa ExaModels.ExaModel + + # This is where the GPU error should occur (in MadNLP, not in our code) + sol1 = madnlp(m1; tol=tolerance, max_iter=max_iter, kwargs...) + obj1 = sol1.objective + + # Non-vectorised version (for comparison) + o2 = @def begin + t ∈ [0, tf], time + x ∈ R², state + u ∈ R, control + x(0) == x0 + ∂(x₁)(t) == x₂(t) + ∂(x₂)(t) == -x₁(t) + u(t) + 0.5∫( x₁(t)^2 + x₂(t)^2 + u(t)^2 ) → min + end + + m2, _ = discretise_exa_full(o2; grid_size=N, backend=backend, scheme=scheme) + @test m2 isa ExaModels.ExaModel + sol2 = madnlp(m2; tol=tolerance, max_iter=max_iter, kwargs...) + obj2 = sol2.objective + + __atol = 1e-9 + @test obj1 - obj2 ≈ 0 atol = __atol + end +end diff --git a/test/test_utils.jl b/test/test_utils.jl index effa2d8..580a11e 100644 --- a/test/test_utils.jl +++ b/test/test_utils.jl @@ -27,13 +27,118 @@ function test_utils() @testset "subs2" begin println("subs2") - e = :(x0[1] * 2xf[3] - cos(xf[2]) * 2x0[2]) - @test subs2(subs2(e, :x0, :x, 0), :xf, :x, :N) == - :(x[1, 0] * (2 * x[3, N]) - cos(x[2, N]) * (2 * x[2, 0])) + # ===== EXISTING FUNCTIONALITY (scalar indexing) ===== - e = :(x0 * 2xf[3] - cos(xf) * 2x0[2]) - @test subs2(subs2(e, :x0, :x, 0), :xf, :x, :N) == - :(x0 * (2 * x[3, N]) - cos(xf) * (2 * x[2, 0])) + @testset "scalar indexing (existing)" begin + # Test 1: Basic scalar substitution + e = :(x0[1] * 2xf[3] - cos(xf[2]) * 2x0[2]) + @test subs2(subs2(e, :x0, :x, 0), :xf, :x, :N) == + :(x[1, 0] * (2 * x[3, N]) - cos(x[2, N]) * (2 * x[2, 0])) + + # Test 2: Bare symbols are NOT substituted + e = :(x0 * 2xf[3] - cos(xf) * 2x0[2]) + @test subs2(subs2(e, :x0, :x, 0), :xf, :x, :N) == + :(x0 * (2 * x[3, N]) - cos(xf) * (2 * x[2, 0])) + + # Test 3: Numeric index + e = :(x0[5] + x0[10]) + @test subs2(e, :x0, :x, 0) == :(x[5, 0] + x[10, 0]) + + # Test 4: Symbolic index + e = :(x0[i] + x0[j]) + @test subs2(e, :x0, :x, 0) == :(x[i, 0] + x[j, 0]) + end + + # ===== NEW FUNCTIONALITY (range indexing) ===== + + @testset "range indexing (new)" begin + # Test 5: Simple range 1:3 + e = :(x0[1:3]) + result = subs2(e, :x0, :x, 0; k = :k) + @test result == :([x[k, 0] for k ∈ 1:3]) + + # Test 6: Range with step 1:2:5 + e = :(x0[1:2:5]) + result = subs2(e, :x0, :x, 0; k = :k) + @test result == :([x[k, 0] for k ∈ 1:2:5]) + + # Test 7: Range with symbolic bounds + e = :(x0[1:n]) + result = subs2(e, :x0, :x, 0; k = :k) + @test result == :([x[k, 0] for k ∈ 1:n]) + + # Test 8: Multiple ranges in same expression + e = :(x0[1:3] + xf[2:4]) + result = subs2(subs2(e, :x0, :x, 0; k = :k1), :xf, :x, :N; k = :k2) + @test result == :([x[k1, 0] for k1 ∈ 1:3] + [x[k2, N] for k2 ∈ 2:4]) + + # Test 9: Range inside function call + e = :(sum(x0[1:n])) + result = subs2(e, :x0, :x, 0; k = :k) + @test result == :(sum([x[k, 0] for k ∈ 1:n])) + end + + @testset "mixed scalar and range" begin + # Test 10: Expression with both scalars and ranges + e = :(x0[1] + x0[2:4] + x0[5]) + result = subs2(e, :x0, :x, 0; k = :k) + # x0[1] → x[1, 0] + # x0[2:4] → [x[k, 0] for k ∈ 2:4] + # x0[5] → x[5, 0] + @test result == :(x[1, 0] + [x[k, 0] for k ∈ 2:4] + x[5, 0]) + end + + @testset "nested and complex expressions" begin + # Test 11: Nested function calls with ranges + e = :(norm(x0[1:3]) + cos(x0[4])) + result = subs2(e, :x0, :x, 0; k = :k) + @test result == :(norm([x[k, 0] for k ∈ 1:3]) + cos(x[4, 0])) + + # Test 12: Range in matrix operations + e = :(A * x0[1:n]) + result = subs2(e, :x0, :x, 0; k = :k) + @test result == :(A * [x[k, 0] for k ∈ 1:n]) + + # Test 13: Multiple substitutions with symbolic j + e = :(x0[1:3] + xf[2:4]) + result = subs2(subs2(e, :x0, :x, :j; k = :k1), :xf, :x, :(j+1); k = :k2) + @test result == :([x[k1, j] for k1 ∈ 1:3] + [x[k2, j+1] for k2 ∈ 2:4]) + end + + @testset "edge cases" begin + # Test 14: Single-element range (should still create comprehension) + e = :(x0[1:1]) + result = subs2(e, :x0, :x, 0; k = :k) + @test result == :([x[k, 0] for k ∈ 1:1]) + + # Test 15: Wrong variable name (should not substitute) + e = :(y0[1:3]) + result = subs2(e, :x0, :x, 0; k = :k) + @test result == e # Unchanged + + # Test 16: Complex symbolic j expression + e = :(x0[1:3]) + result = subs2(e, :x0, :x, :grid_size; k = :k) + @test result == :([x[k, grid_size] for k ∈ 1:3]) + + # Test 17: Scalar index that is a range expression (should not match) + # This tests that we properly distinguish i (scalar) from 1:3 (range) + e = :(x0[i]) + result = subs2(e, :x0, :x, 0; k = :k) + @test result == :(x[i, 0]) # Scalar behavior + end + + @testset "backward compatibility" begin + # Test 18: Scalar indexing still works + e = :(x0[1] * 2xf[3] - cos(xf[2]) * 2x0[2]) + @test subs2(subs2(e, :x0, :x, 0), :xf, :x, :N) == + :(x[1, 0] * (2 * x[3, N]) - cos(x[2, N]) * (2 * x[2, 0])) + + # Test 19: Bare symbols are NOT substituted + e = :(x0 * 2xf[3] - cos(xf) * 2x0[2]) + @test subs2(subs2(e, :x0, :x, 0), :xf, :x, :N) == + :(x0 * (2 * x[3, N]) - cos(xf) * (2 * x[2, 0])) + end end @testset "subs3" begin @@ -44,28 +149,59 @@ function test_utils() @test subs3(e, :xf, :x, 1, :N) == :(x0[1:2:d] * (2 * x[1, N])) end - @testset "subs4" begin - println("subs4") + @testset "subs2m" begin + println("subs2m") - e = :(v[1:2:d] * 2xf[1:3]) - @test subs4(e, :v, :v, :i) == :(v[i] * (2 * xf[1:3])) - @test subs4(e, :xf, :xf, 1) == :(v[1:2:d] * (2 * xf[1])) - end + @testset "range indexing" begin + # Test 1: Basic range substitution + e = :(x0[1:3]) + result = subs2m(e, :x0, :x, 0; k = :k) + @test result == :([((x[k, 0] + x[k, 0 + 1]) / 2) for k ∈ 1:3]) + + # Test 2: Range with step + e = :(x0[1:2:5]) + result = subs2m(e, :x0, :x, 0; k = :k) + @test result == :([((x[k, 0] + x[k, 0 + 1]) / 2) for k ∈ 1:2:5]) + + # Test 3: Range in arithmetic expression + e = :(2 * x0[1:3]) + result = subs2m(e, :x0, :x, 0; k = :k) + @test result == :(2 * [((x[k, 0] + x[k, 0 + 1]) / 2) for k ∈ 1:3]) - @testset "subs5" begin - println("subs5") + # Test 4: Multiple ranges in same expression + e = :(x0[1:2] + xf[2:4]) + result = subs2m(subs2m(e, :x0, :x, 0; k = :k), :xf, :x, :N; k = :k) + @test result == :( + [((x[k, 0] + x[k, 0 + 1]) / 2) for k ∈ 1:2] + + [((x[k, N] + x[k, N + 1]) / 2) for k ∈ 2:4] + ) - e = :(x0[1] * 2xf[3] - cos(xf[2]) * 2x0[2]) - @test subs5(subs5(e, :x0, :x, 0), :xf, :x, :N) == :( - ((x[1, 0] + x[1, 0 + 1]) / 2) * (2 * ((x[3, N] + x[3, N + 1]) / 2)) - - cos((x[2, N] + x[2, N + 1]) / 2) * (2 * ((x[2, 0] + x[2, 0 + 1]) / 2)) - ) + # Test 5: Range with symbolic j + e = :(x0[1:3]) + result = subs2m(e, :x0, :x, :j; k = :k) + @test result == :([((x[k, j] + x[k, j + 1]) / 2) for k ∈ 1:3]) - e = :(x0 * 2xf[3] - cos(xf) * 2x0[2]) - @test subs5(subs5(e, :x0, :x, 0), :xf, :x, :N) == :( - x0 * (2 * ((x[3, N] + x[3, N + 1]) / 2)) - - cos(xf) * (2 * ((x[2, 0] + x[2, 0 + 1]) / 2)) - ) + # Test 6: Single-element range + e = :(x0[2:2]) + result = subs2m(e, :x0, :x, 0; k = :k) + @test result == :([((x[k, 0] + x[k, 0 + 1]) / 2) for k ∈ 2:2]) + end + + @testset "backward compatibility" begin + # Test 7: Scalar indexing still works + e = :(x0[1] * 2xf[3] - cos(xf[2]) * 2x0[2]) + @test subs2m(subs2m(e, :x0, :x, 0), :xf, :x, :N) == :( + ((x[1, 0] + x[1, 0 + 1]) / 2) * (2 * ((x[3, N] + x[3, N + 1]) / 2)) - + cos((x[2, N] + x[2, N + 1]) / 2) * (2 * ((x[2, 0] + x[2, 0 + 1]) / 2)) + ) + + # Test 8: Bare symbols are NOT substituted + e = :(x0 * 2xf[3] - cos(xf) * 2x0[2]) + @test subs2m(subs2m(e, :x0, :x, 0), :xf, :x, :N) == :( + x0 * (2 * ((x[3, N] + x[3, N + 1]) / 2)) - + cos(xf) * (2 * ((x[2, 0] + x[2, 0 + 1]) / 2)) + ) + end end @testset "replace_call" begin diff --git a/test/test_utils_bis.jl b/test/test_utils_bis.jl index 4b39ef3..6f48764 100644 --- a/test/test_utils_bis.jl +++ b/test/test_utils_bis.jl @@ -78,16 +78,15 @@ function test_utils_bis() @test constraint_type(e, t, t0, tf, x, u, v) == :variable_fun end - @testset "subs2/3/4/5 (pathological cases)" begin - println("subs2/3/4/5 (bis)") + @testset "subs2/2m/3 (pathological cases)" begin + println("subs2/2m/3 (bis)") e = :(x0[1] * 2xf[3]) # symbol does not appear at all → expression unchanged @test subs2(e, :z, :x, 0) == e + @test subs2m(e, :z, :x, 0) == e @test subs3(e, :z, :x, :i, 0) == e - @test subs4(e, :z, :z, :i) == e - @test subs5(e, :z, :x, 0) == e end @testset "replace_call (errors)" begin