Skip to content
12 changes: 7 additions & 5 deletions src/MeshField.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -189,11 +189,11 @@ class OmegahMeshField {
static_assert(dim == 1 || dim == 2 || dim == 3);
}

template <typename DataType, size_t order>
template <typename DataType, size_t order, size_t numComp>
// Ordering of field indexing changed to 'entity, node, component'
auto CreateLagrangeField() {
return MeshField::CreateLagrangeField<ExecutionSpace, Controller, DataType,
order, dim>(meshInfo);
order, dim, numComp>(meshInfo);
}

auto getCoordField() { return coordField; }
Expand All @@ -219,12 +219,13 @@ class OmegahMeshField {
return offsets;
}

// evaluate a field at the specified local coordinates for each triangle
// evaluate a field at the specified local coordinate for each triangle
template <typename ViewType, typename ShapeField>
auto triangleLocalPointEval(ViewType localCoords, size_t NumPtsPerElem,
ShapeField field) {
auto offsets = createOffsets(meshInfo.numTri, NumPtsPerElem);
auto eval = triangleLocalPointEval(localCoords, offsets, field);
auto eval = triangleLocalPointEval<ViewType, ShapeField>(localCoords,
offsets, field);
return eval;
}

Expand All @@ -243,7 +244,8 @@ class OmegahMeshField {

const auto [shp, map] = Omegah::getTriangleElement<ShapeOrder>(mesh);

MeshField::FieldElement f(meshInfo.numTri, field, shp, map);
MeshField::FieldElement<ShapeField, decltype(shp), decltype(map)> f(
meshInfo.numTri, field, shp, map);
auto eval = MeshField::evaluate(f, localCoords, offsets);
return eval;
}
Expand Down
8 changes: 4 additions & 4 deletions src/MeshField_Element.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -150,8 +150,8 @@ struct FieldElement {
};
using ValArray =
Kokkos::Array<typename baseType<typename FieldAccessor::BaseType>::type,
ShapeType::numComponentsPerDof>;
static const size_t NumComponents = ShapeType::numComponentsPerDof;
FieldAccessor::numComp>;
static const size_t NumComponents = FieldAccessor::numComp;

/**
* @brief
Expand All @@ -173,11 +173,11 @@ struct FieldElement {
assert(ent < numMeshEnts);
ValArray c;
const auto shapeValues = shapeFn.getValues(localCoord);
for (int ci = 0; ci < shapeFn.numComponentsPerDof; ++ci)
for (int ci = 0; ci < NumComponents; ++ci)
c[ci] = 0;
for (auto topo : elm2dof.getTopology()) { // element topology
for (int ni = 0; ni < shapeFn.numNodes; ++ni) {
for (int ci = 0; ci < shapeFn.numComponentsPerDof; ++ci) {
for (int ci = 0; ci < NumComponents; ++ci) {
auto map = elm2dof(ni, ci, ent, topo);
const auto fval =
field(map.entity, map.node, map.component, map.topo);
Expand Down
4 changes: 0 additions & 4 deletions src/MeshField_Shape.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ using Vector4 = Kokkos::Array<Real, 4>;

struct LinearEdgeShape {
static const size_t numNodes = 2;
static const size_t numComponentsPerDof = 1;
static const size_t meshEntDim = 1;
constexpr static Mesh_Topology DofHolders[1] = {Vertex};
constexpr static size_t Order = 1;
Expand Down Expand Up @@ -90,7 +89,6 @@ struct LinearTriangleShape {

struct LinearTriangleCoordinateShape {
static const size_t numNodes = 3;
static const size_t numComponentsPerDof = 2;
static const size_t meshEntDim = 2;
constexpr static Mesh_Topology DofHolders[1] = {Vertex};
constexpr static size_t Order = 1;
Expand All @@ -109,7 +107,6 @@ struct LinearTriangleCoordinateShape {

struct QuadraticTriangleShape {
static const size_t numNodes = 6;
static const size_t numComponentsPerDof = 1;
static const size_t meshEntDim = 2;
constexpr static Mesh_Topology DofHolders[2] = {Vertex, Edge};
constexpr static size_t NumDofHolders[2] = {3, 3};
Expand Down Expand Up @@ -149,7 +146,6 @@ struct QuadraticTriangleShape {

struct QuadraticTetrahedronShape {
static const size_t numNodes = 10;
static const size_t numComponentsPerDof = 1;
static const size_t meshEntDim = 3;
constexpr static Mesh_Topology DofHolders[2] = {Vertex, Edge};
constexpr static size_t NumDofHolders[2] = {4, 6};
Expand Down
31 changes: 18 additions & 13 deletions src/MeshField_ShapeField.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,12 @@ struct MeshInfo {
* @param meshInfoIn defines on-process mesh metadata
* @param mixins object(s) needed to construct the Accessor
*/
template <typename MeshFieldType, typename Shape, typename... Mixins>
template <size_t numCompIn, typename MeshFieldType, typename Shape,
typename... Mixins>
struct ShapeField : public Mixins... {
MeshFieldType meshField;
Shape shape;
static const size_t numComp = numCompIn;
const MeshInfo meshInfo;
constexpr static auto Order = Shape::Order;
ShapeField(MeshFieldType &meshFieldIn, const MeshInfo &meshInfoIn,
Expand Down Expand Up @@ -158,7 +160,7 @@ template <typename VtxAccessor> struct LinearAccessor {
template <typename ExecutionSpace,
template <typename...>
typename Controller = MeshField::KokkosController,
typename DataType, size_t order, size_t dim>
typename DataType, size_t order, size_t dim, size_t numComp>
auto CreateLagrangeField(const MeshInfo &meshInfo) {
static_assert((std::is_same_v<Real4, DataType> == true ||
std::is_same_v<Real8, DataType> == true),
Expand All @@ -179,10 +181,10 @@ auto CreateLagrangeField(const MeshInfo &meshInfo) {
std::is_same_v<
Controller<ExecutionSpace, MemorySpace, DataType>,
MeshField::CabanaController<ExecutionSpace, MemorySpace, DataType>>,
Controller<ExecutionSpace, MemorySpace, DataType[1][1]>,
Controller<ExecutionSpace, MemorySpace, DataType[1][numComp]>,
Controller<MemorySpace, ExecutionSpace, DataType ***>>;
// 1 dof with 1 component per vtx
auto createController = [](const int numComp, auto numVtx) {
auto createController = [](auto numVtx) {
if constexpr (std::is_same_v<
Controller<ExecutionSpace, MemorySpace, DataType>,
MeshField::CabanaController<ExecutionSpace, MemorySpace,
Expand All @@ -192,14 +194,15 @@ auto CreateLagrangeField(const MeshInfo &meshInfo) {
return Ctrlr({/*field 0*/ numVtx, 1, numComp});
}
};
Ctrlr kk_ctrl = createController(1, meshInfo.numVtx);
Ctrlr kk_ctrl = createController(meshInfo.numVtx);
#else
using Ctrlr = Controller<MemorySpace, ExecutionSpace, DataType ***>;
Ctrlr kk_ctrl({/*field 0*/ meshInfo.numVtx, 1, 1});
#endif
auto vtxField = MeshField::makeField<Ctrlr, 0>(kk_ctrl);
using LA = LinearAccessor<decltype(vtxField)>;
using LinearLagrangeShapeField = ShapeField<Ctrlr, LinearTriangleShape, LA>;
using LinearLagrangeShapeField =
ShapeField<numComp, Ctrlr, LinearTriangleShape, LA>;
LinearLagrangeShapeField llsf(kk_ctrl, meshInfo, {vtxField});
return llsf;
} else if constexpr (order == 2 && (dim == 2 || dim == 3)) {
Expand All @@ -214,10 +217,11 @@ auto CreateLagrangeField(const MeshInfo &meshInfo) {
std::is_same_v<
Controller<ExecutionSpace, MemorySpace, DataType>,
MeshField::CabanaController<ExecutionSpace, MemorySpace, DataType>>,
Controller<ExecutionSpace, MemorySpace, DataType[1][1], DataType[1][1]>,
Controller<ExecutionSpace, MemorySpace, DataType[1][numComp],
DataType[1][numComp]>,
Controller<MemorySpace, ExecutionSpace, DataType ***, DataType ***>>;
// 1 dof with 1 comp per vtx/edge
auto createController = [](const int numComp, auto numVtx, auto numEdge) {
auto createController = [](auto numVtx, auto numEdge) {
if constexpr (std::is_same_v<
Controller<ExecutionSpace, MemorySpace, DataType>,
MeshField::CabanaController<ExecutionSpace, MemorySpace,
Expand All @@ -228,18 +232,18 @@ auto CreateLagrangeField(const MeshInfo &meshInfo) {
/*field 1*/ numEdge, 1, numComp});
}
};
Ctrlr kk_ctrl = createController(1, meshInfo.numVtx, meshInfo.numEdge);
Ctrlr kk_ctrl = createController(meshInfo.numVtx, meshInfo.numEdge);
#else
using Ctrlr =
Controller<MemorySpace, ExecutionSpace, DataType ***, DataType ***>;
Ctrlr kk_ctrl({/*field 0*/ meshInfo.numVtx, 1, 1,
/*field 1*/ meshInfo.numEdge, 1, 1});
Ctrlr kk_ctrl({/*field 0*/ meshInfo.numVtx, 1, numComp,
/*field 1*/ meshInfo.numEdge, 1, numComp});
#endif
auto vtxField = MeshField::makeField<Ctrlr, 0>(kk_ctrl);
auto edgeField = MeshField::makeField<Ctrlr, 1>(kk_ctrl);
using QA = QuadraticAccessor<decltype(vtxField), decltype(edgeField)>;
using QuadraticLagrangeShapeField =
ShapeField<Ctrlr, QuadraticTriangleShape, QA>;
ShapeField<numComp, Ctrlr, QuadraticTriangleShape, QA>;
QuadraticLagrangeShapeField qlsf(kk_ctrl, meshInfo, {vtxField, edgeField});
return qlsf;
} else {
Expand Down Expand Up @@ -302,7 +306,8 @@ auto CreateCoordinateField(const MeshInfo &meshInfo) {
#endif
auto vtxField = MeshField::makeField<Ctrlr, 0>(kk_ctrl);
using LA = LinearAccessor<decltype(vtxField)>;
using LinearLagrangeShapeField = ShapeField<Ctrlr, LinearTriangleShape, LA>;
using LinearLagrangeShapeField =
ShapeField<dim, Ctrlr, LinearTriangleShape, LA>;
LinearLagrangeShapeField llsf(kk_ctrl, meshInfo, {vtxField});
return llsf;
};
Expand Down
8 changes: 4 additions & 4 deletions test/testElement.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ void triangleLocalPointEval() {
const auto numElms = 3; // provided by the mesh
const MeshField::MeshInfo meshInfo{.numVtx = 5, .numTri = 3};
auto field = MeshField::CreateLagrangeField<
ExecutionSpace, MeshField::KokkosController, MeshField::Real, 1, 2>(
ExecutionSpace, MeshField::KokkosController, MeshField::Real, 1, 2, 1>(
meshInfo);

MeshField::FieldElement f(numElms, field, MeshField::LinearTriangleShape(),
Expand Down Expand Up @@ -85,7 +85,7 @@ struct LinearEdgeToVertexField {
void edgeLocalPointEval() {
const MeshField::MeshInfo meshInfo{.numVtx = 5, .numEdge = 7, .dim = 1};
auto field = MeshField::CreateLagrangeField<
ExecutionSpace, MeshField::KokkosController, MeshField::Real, 1, 1>(
ExecutionSpace, MeshField::KokkosController, MeshField::Real, 1, 1, 1>(
meshInfo);

MeshField::FieldElement f(meshInfo.numEdge, field,
Expand Down Expand Up @@ -137,7 +137,7 @@ void quadraticTriangleLocalPointEval() {
const MeshField::MeshInfo meshInfo{
.numVtx = 3, .numEdge = 3, .numTri = 1, .dim = 2};
auto field = MeshField::CreateLagrangeField<
ExecutionSpace, MeshField::KokkosController, MeshField::Real, 2, 2>(
ExecutionSpace, MeshField::KokkosController, MeshField::Real, 2, 2, 1>(
meshInfo);

MeshField::FieldElement f(meshInfo.numTri, field,
Expand Down Expand Up @@ -192,7 +192,7 @@ void quadraticTetrahedronLocalPointEval() {
auto field =
MeshField::CreateLagrangeField<ExecutionSpace,
MeshField::KokkosController,
MeshField::Real, ShapeOrder, MeshDim>(
MeshField::Real, ShapeOrder, MeshDim, 1>(
meshInfo);

MeshField::FieldElement f(meshInfo.numTet, field,
Expand Down
47 changes: 40 additions & 7 deletions test/testOmegahElement.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -170,9 +170,9 @@ void doRun(Omega_h::Mesh &mesh,
using ViewType = decltype(testCase.coords);
{
const auto ShapeOrder = 1;

auto field =
omf.template CreateLagrangeField<MeshField::Real, ShapeOrder>();
const auto numComponents = 1;
auto field = omf.template CreateLagrangeField<MeshField::Real, ShapeOrder,
numComponents>();
LinearFunction func = LinearFunction();
setVertices(mesh, func, field);
using FieldType = decltype(field);
Expand All @@ -186,8 +186,9 @@ void doRun(Omega_h::Mesh &mesh,

{
const auto ShapeOrder = 2;
auto field =
omf.template CreateLagrangeField<MeshField::Real, ShapeOrder>();
const auto numComponents = 1;
auto field = omf.template CreateLagrangeField<MeshField::Real, ShapeOrder,
numComponents>();
auto func = QuadraticFunction();
setVertices(mesh, func, field);
setEdges(mesh, func, field);
Expand All @@ -202,8 +203,9 @@ void doRun(Omega_h::Mesh &mesh,

{
const auto ShapeOrder = 2;
auto field =
omf.template CreateLagrangeField<MeshField::Real, ShapeOrder>();
const auto numComponents = 1;
auto field = omf.template CreateLagrangeField<MeshField::Real, ShapeOrder,
numComponents>();
auto func = LinearFunction();
setVertices(mesh, func, field);
setEdges(mesh, func, field);
Expand All @@ -215,6 +217,37 @@ void doRun(Omega_h::Mesh &mesh,
if (failed)
doFail("quadratic", "linear", testCase.name);
}

{
const auto ShapeOrder = 1;
const auto numComponents = 2;
auto field = omf.template CreateLagrangeField<MeshField::Real, ShapeOrder,
numComponents>();
LinearFunction func = LinearFunction();
setVertices(mesh, func, field);
using FieldType = decltype(field);
auto result = omf.template triangleLocalPointEval<ViewType, FieldType>(
testCase.coords, testCase.NumPtsPerElem, field);
auto failed = checkResult(mesh, result, omf.getCoordField(), testCase,
LinearFunction{});
if (failed)
doFail("linear", "linear", testCase.name);
}
{
const auto ShapeOrder = 1;
const auto numComponents = 3;
auto field = omf.template CreateLagrangeField<MeshField::Real, ShapeOrder,
numComponents>();
LinearFunction func = LinearFunction();
setVertices(mesh, func, field);
using FieldType = decltype(field);
auto result = omf.template triangleLocalPointEval<ViewType, FieldType>(
testCase.coords, testCase.NumPtsPerElem, field);
auto failed = checkResult(mesh, result, omf.getCoordField(), testCase,
LinearFunction{});
if (failed)
doFail("linear", "linear", testCase.name);
}
}
}

Expand Down
Loading