Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/integral.jl
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ function _integral(
# Create a wrapper that returns only the value component in those units
uintegrand(ts) = Unitful.ustrip.(integrandunits, integrand(ts))
# Integrate only the unitless values
value = HCubature.hcubature(uintegrand, zeros(FP, N), ones(FP, N); rule.kwargs...)[1]
value = HCubature.hcubature(uintegrand, _zeros(FP, N), _ones(FP, N); rule.kwargs...)[1]

# Reapply units
return value .* integrandunits
Expand Down
2 changes: 1 addition & 1 deletion src/specializations/BezierCurve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ function integral(
# Create a wrapper that returns only the value component in those units
uintegrand(ts) = Unitful.ustrip.(integrandunits, integrand(ts))
# Integrate only the unitless values
value = HCubature.hcubature(uintegrand, zeros(FP, 1), ones(FP, 1); rule.kwargs...)[1]
value = HCubature.hcubature(uintegrand, _zeros(FP, 1), _ones(FP, 1); rule.kwargs...)[1]

# Reapply units
return value .* integrandunits
Expand Down
4 changes: 2 additions & 2 deletions src/specializations/Line.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,11 +68,11 @@ function integral(

# Integrate f along the Line
differential(line, x) = t′(x) * _units(line(0))
integrand(x::AbstractVector) = f(line(t(x[1]))) * differential(line, x[1])
integrand(xs) = f(line(t(xs[1]))) * differential(line, xs[1])

# HCubature doesn't support functions that output Unitful Quantity types
# Establish the units that are output by f
testpoint_parametriccoord = FP[0.5]
testpoint_parametriccoord = (FP(0.5),)
integrandunits = Unitful.unit.(integrand(testpoint_parametriccoord))
# Create a wrapper that returns only the value component in those units
uintegrand(uv) = Unitful.ustrip.(integrandunits, integrand(uv))
Expand Down
4 changes: 3 additions & 1 deletion src/specializations/Plane.jl
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,9 @@ function integral(
# Create a wrapper that returns only the value component in those units
uintegrand(uv) = Unitful.ustrip.(integrandunits, integrand(uv))
# Integrate only the unitless values
value = HCubature.hcubature(uintegrand, -ones(FP, 2), ones(FP, 2); rule.kwargs...)[1]
a = 0 .- _ones(FP, 2)
b = _ones(FP, 2)
value = HCubature.hcubature(uintegrand, a, b; rule.kwargs...)[1]

# Reapply units
return value .* integrandunits
Expand Down
2 changes: 1 addition & 1 deletion src/specializations/Ray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ function integral(
# Create a wrapper that returns only the value component in those units
uintegrand(uv) = Unitful.ustrip.(integrandunits, integrand(uv))
# Integrate only the unitless values
value = HCubature.hcubature(uintegrand, zeros(FP, 1), ones(FP, 1); rule.kwargs...)[1]
value = HCubature.hcubature(uintegrand, _zeros(FP, 1), _ones(FP, 1); rule.kwargs...)[1]

# Reapply units
return value .* integrandunits
Expand Down
2 changes: 1 addition & 1 deletion src/specializations/Triangle.jl
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ function integral(
v = R * (1 - b / (a + b))
return f(triangle(u, v)) * R / (a + b)^2
end
∫ = HCubature.hcubature(integrand, zeros(FP, 2), FP[1, π / 2], rule.kwargs...)[1]
∫ = HCubature.hcubature(integrand, _zeros(FP, 2), (FP(1), FP(π / 2)), rule.kwargs...)[1]

# Apply a linear domain-correction factor 0.5 ↦ area(triangle)
return 2 * Meshes.area(triangle) .* ∫
Expand Down
8 changes: 8 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,14 @@ function _error_unsupported_combination(geometry, rule)
throw(ArgumentError(msg))
end

# Return an NTuple{N, T} of zeros; same interface as Base.zeros() but faster
_zeros(T::DataType, N::Int64) = ntuple(_ -> zero(T), N)
_zeros(N::Int) = _zeros(Float64, N)

# Return an NTuple{N, T} of ones; same interface as Base.ones() but faster
_ones(T::DataType, N::Int64) = ntuple(_ -> one(T), N)
_ones(N::Int) = _ones(Float64, N)

################################################################################
# DifferentiationMethod
################################################################################
Expand Down
11 changes: 10 additions & 1 deletion test/utils.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,22 @@
@testitem "Utilities" setup=[Setup] begin
using LinearAlgebra: norm
using MeshIntegrals: _units, _zeros, _ones

# _KVector
v = Meshes.Vec(3, 4)
@test norm(MeshIntegrals._KVector(v)) ≈ 5.0u"m"

# _units
p = Point(1.0u"cm", 2.0u"mm", 3.0u"m")
@test MeshIntegrals._units(p) == u"m"
@test _units(p) == u"m"

# _zeros
@test _zeros(2) == (0.0, 0.0)
@test _zeros(Float32, 2) == (0.0f0, 0.0f0)

# _ones
@test _ones(2) == (1.0, 1.0)
@test _ones(Float32, 2) == (1.0f0, 1.0f0)
end

@testitem "DifferentiationMethod" setup=[Setup] begin
Expand Down