Skip to content

Commit 06d4506

Browse files
committed
add pmpl files
1 parent 510cd43 commit 06d4506

File tree

3 files changed

+290
-0
lines changed

3 files changed

+290
-0
lines changed

src/common/pmpl.hpp

Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
#pragma once
2+
3+
4+
#include "ShivaMacros.hpp"
5+
6+
namespace shiva
7+
{
8+
#if defined(HPCREACT_USE_DEVICE)
9+
#if defined(HPCREACT_USE_CUDA)
10+
#define deviceMalloc( PTR, BYTES ) cudaMalloc( PTR, BYTES );
11+
#define deviceMallocManaged( PTR, BYTES ) cudaMallocManaged( PTR, BYTES );
12+
#define deviceDeviceSynchronize() cudaDeviceSynchronize();
13+
#define deviceMemCpy( DST, SRC, BYTES, KIND ) cudaMemcpy( DST, SRC, BYTES, KIND );
14+
#define deviceFree( PTR ) cudaFree( PTR );
15+
#elif defined(HPCREACT_USE_HIP)
16+
#define deviceMalloc( PTR, BYTES ) hipMalloc( PTR, BYTES );
17+
#define deviceMallocManaged( PTR, BYTES ) hipMallocManaged( PTR, BYTES );
18+
#define deviceDeviceSynchronize() hipDeviceSynchronize();
19+
#define deviceMemCpy( DST, SRC, BYTES, KIND ) hipMemcpy( DST, SRC, BYTES, KIND );
20+
#define deviceFree( PTR ) hipFree( PTR );
21+
#endif
22+
#endif
23+
24+
/**
25+
* @namespace shiva::pmpl
26+
* @brief The pmpl namespace contains all of the pmpl classes and functions
27+
* used to provide a portablity layer in unit testing.
28+
*/
29+
namespace pmpl
30+
{
31+
32+
/**
33+
* @brief This function checks if two floating point numbers are equal within a
34+
* tolerance.
35+
* @tparam REAL_TYPE This is the type of the floating point numbers to compare.
36+
* @param a This is the first floating point number to compare.
37+
* @param b This is the second floating point number to compare.
38+
* @param tolerance This is the tolerance to use when comparing the two numbers.
39+
* @return This returns true if the two numbers are equal within the tolerance.
40+
*/
41+
template< typename REAL_TYPE >
42+
static constexpr bool check( REAL_TYPE const a, REAL_TYPE const b, REAL_TYPE const tolerance )
43+
{
44+
return ( a - b ) * ( a - b ) < tolerance * tolerance;
45+
}
46+
47+
48+
/**
49+
* @brief This function provides a generic kernel execution mechanism that can
50+
* be called on either host or device.
51+
* @tparam LAMBDA The type of the lambda function to execute.
52+
* @param func The lambda function to execute.
53+
*/
54+
template< typename LAMBDA >
55+
HPCREACT_GLOBAL void genericKernel( LAMBDA func )
56+
{
57+
func();
58+
}
59+
60+
/**
61+
* @brief This function provides a wrapper to the genericKernel function.
62+
* @tparam LAMBDA The type of the lambda function to execute.
63+
* @param func The lambda function to execute.
64+
*
65+
* This function will execute the lambda through a kernel launch of
66+
* genericKernel.
67+
*/
68+
template< typename LAMBDA >
69+
void genericKernelWrapper( LAMBDA && func )
70+
{
71+
#if defined(HPCREACT_USE_DEVICE)
72+
genericKernel << < 1, 1 >> > ( std::forward< LAMBDA >( func ) );
73+
#else
74+
genericKernel( std::forward< LAMBDA >( func ) );
75+
#endif
76+
}
77+
78+
79+
80+
/**
81+
* @brief This function provides a generic kernel execution mechanism that can
82+
* be called on either host or device.
83+
* @tparam DATA_TYPE The type of the data pointer.
84+
* @tparam LAMBDA The type of the lambda function to execute.
85+
* @param func The lambda function to execute.
86+
* @param data A general data pointer to pass to the lambda function that should
87+
* hold all data required to execute the lambda function, aside from what is
88+
* captured.
89+
*/
90+
template< typename DATA_TYPE, typename LAMBDA >
91+
HPCREACT_GLOBAL void genericKernel( LAMBDA func, DATA_TYPE * const data )
92+
{
93+
func( data );
94+
}
95+
96+
/**
97+
* @brief This function provides a wrapper to the genericKernel function.
98+
* @tparam DATA_TYPE The type of the data pointer.
99+
* @tparam LAMBDA The type of the lambda function to execute.
100+
* @param N The size of the data array.
101+
* @param hostData The data pointer to pass to the lambda function.
102+
* @param func The lambda function to execute.
103+
*
104+
* This function will allocate the data pointer on the device, execute the
105+
* lambda through a kernel launch of genericKernel, and then synchronize the
106+
* device.
107+
*/
108+
template< typename DATA_TYPE, typename LAMBDA >
109+
void genericKernelWrapper( int const N, DATA_TYPE * const hostData, LAMBDA && func )
110+
{
111+
112+
#if defined(HPCREACT_USE_DEVICE)
113+
DATA_TYPE * deviceData;
114+
deviceMalloc( &deviceData, N * sizeof(DATA_TYPE) );
115+
genericKernel <<< 1, 1 >>> ( std::forward< LAMBDA >( func ), deviceData );
116+
deviceDeviceSynchronize();
117+
deviceMemCpy( hostData, deviceData, N * sizeof(DATA_TYPE), cudaMemcpyDeviceToHost );
118+
deviceFree( deviceData );
119+
#else
120+
HPCREACT_UNUSED_VAR( N );
121+
genericKernel( std::forward< LAMBDA >( func ), hostData );
122+
#endif
123+
}
124+
125+
/**
126+
* @brief convenience function for allocating data allocated on a pointer
127+
* @tparam DATA_TYPE The type of the data pointer.
128+
* @param data The data pointer to deallocate.
129+
*/
130+
template< typename DATA_TYPE >
131+
HPCREACT_CONSTEXPR_HOSTDEVICE_FORCEINLINE void deallocateData( DATA_TYPE * data )
132+
{
133+
#if defined(HPCREACT_USE_DEVICE)
134+
deviceFree( data );
135+
#else
136+
delete[] data;
137+
#endif
138+
}
139+
140+
} // namespace pmpl
141+
} // namespace shiva
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
# Specify list of tests
2+
set( testSourceFiles
3+
testReactionsParameterData.cpp )
4+
5+
set( dependencyList hpcReact gtest )
6+
7+
# Add gtest C++ based tests
8+
foreach(test ${testSourceFiles})
9+
get_filename_component( test_name ${test} NAME_WE )
10+
blt_add_executable( NAME ${test_name}
11+
SOURCES ${test}
12+
OUTPUT_DIR ${TEST_OUTPUT_DIRECTORY}
13+
DEPENDS_ON ${dependencyList} )
14+
blt_add_test( NAME ${test_name}
15+
COMMAND ${test_name} )
16+
endforeach()
Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
2+
#include "../ReactionsBase_impl.hpp"
3+
#include "MultiVector.hpp"
4+
#include "common/CArrayWrapper.hpp"
5+
#include "../ReactionsParameterDataPredefined.hpp"
6+
7+
#include <gtest/gtest.h>
8+
9+
using namespace hpcReact;
10+
11+
template< typename T >
12+
constexpr bool nearEqual( T a, T b, T tol = 1e-7 )
13+
{
14+
return ( a - b ) * ( a - b ) < tol * tol;
15+
}
16+
17+
TEST( testReactionsBase, testParamsInitialization )
18+
{
19+
static_assert( chemicalReactionsParams.numPrimarySpecies == 7, "Number of primary species is not 7" );
20+
static_assert( chemicalReactionsParams.numSecondarySpecies == 11, "Number of secondary species is not 11" );
21+
static_assert( nearEqual( chemicalReactionsParams.m_ionSizePrimary[0], 9.00 ), "Ion size for primary species 0 is not 9.00" );
22+
static_assert( nearEqual( chemicalReactionsParams.m_ionSizePrimary[1], 4.00 ), "Ion size for primary species 1 is not 4.00" );
23+
static_assert( nearEqual( chemicalReactionsParams.m_ionSizePrimary[2], 6.00 ), "Ion size for primary species 2 is not 6.00" );
24+
static_assert( nearEqual( chemicalReactionsParams.m_ionSizePrimary[3], 4.00 ), "Ion size for primary species 3 is not 4.00" );
25+
static_assert( nearEqual( chemicalReactionsParams.m_ionSizePrimary[4], 3.00 ), "Ion size for primary species 4 is not 3.00" );
26+
static_assert( nearEqual( chemicalReactionsParams.m_ionSizePrimary[5], 8.00 ), "Ion size for primary species 5 is not 8.00" );
27+
static_assert( nearEqual( chemicalReactionsParams.m_ionSizePrimary[6], 4.00 ), "Ion size for primary species 6 is not 4.00" );
28+
static_assert( nearEqual( chemicalReactionsParams.m_ionSizeSec[0] , 3.50 ), "Ion size for secondary species 0 is not 3.50" );
29+
static_assert( nearEqual( chemicalReactionsParams.m_ionSizeSec[1] , 3.00 ), "Ion size for secondary species 1 is not 3.00" );
30+
static_assert( nearEqual( chemicalReactionsParams.m_ionSizeSec[2] , 4.50 ), "Ion size for secondary species 2 is not 4.50" );
31+
static_assert( nearEqual( chemicalReactionsParams.m_ionSizeSec[3] , 3.00 ), "Ion size for secondary species 3 is not 3.00" );
32+
static_assert( nearEqual( chemicalReactionsParams.m_ionSizeSec[4] , 4.00 ), "Ion size for secondary species 4 is not 4.00" );
33+
static_assert( nearEqual( chemicalReactionsParams.m_ionSizeSec[5] , 3.00 ), "Ion size for secondary species 5 is not 3.00" );
34+
static_assert( nearEqual( chemicalReactionsParams.m_ionSizeSec[6] , 3.00 ), "Ion size for secondary species 6 is not 3.00" );
35+
static_assert( nearEqual( chemicalReactionsParams.m_ionSizeSec[7] , 4.00 ), "Ion size for secondary species 7 is not 4.00" );
36+
static_assert( nearEqual( chemicalReactionsParams.m_ionSizeSec[8] , 3.00 ), "Ion size for secondary species 8 is not 3.00" );
37+
static_assert( nearEqual( chemicalReactionsParams.m_ionSizeSec[9] , 3.00 ), "Ion size for secondary species 9 is not 3.00" );
38+
static_assert( nearEqual( chemicalReactionsParams.m_ionSizeSec[10], 4.00 ), "Ion size for secondary species 10 is not 4.00" );
39+
static_assert( chemicalReactionsParams.m_chargePrimary[0] == 1, "Charge for primary species 0 is not 1" );
40+
static_assert( chemicalReactionsParams.m_chargePrimary[1] == -1, "Charge for primary species 1 is not -1" );
41+
static_assert( chemicalReactionsParams.m_chargePrimary[2] == 2, "Charge for primary species 2 is not 2" );
42+
static_assert( chemicalReactionsParams.m_chargePrimary[3] == -2, "Charge for primary species 3 is not -2" );
43+
static_assert( chemicalReactionsParams.m_chargePrimary[4] == -1, "Charge for primary species 4 is not -1" );
44+
static_assert( chemicalReactionsParams.m_chargePrimary[5] == 2, "Charge for primary species 5 is not 2" );
45+
static_assert( chemicalReactionsParams.m_chargePrimary[6] == 1, "Charge for primary species 6 is not 1" );
46+
static_assert( chemicalReactionsParams.m_chargeSec[0] == -1, "Charge for secondary species 0 is not -1" );
47+
static_assert( chemicalReactionsParams.m_chargeSec[1] == 0, "Charge for secondary species 1 is not 0" );
48+
static_assert( chemicalReactionsParams.m_chargeSec[2] == -2, "Charge for secondary species 2 is not -2" );
49+
static_assert( chemicalReactionsParams.m_chargeSec[3] == 0, "Charge for secondary species 3 is not 0" );
50+
static_assert( chemicalReactionsParams.m_chargeSec[4] == 1, "Charge for secondary species 4 is not 1" );
51+
static_assert( chemicalReactionsParams.m_chargeSec[5] == 0, "Charge for secondary species 5 is not 0" );
52+
static_assert( chemicalReactionsParams.m_chargeSec[6] == 0, "Charge for secondary species 6 is not 0" );
53+
static_assert( chemicalReactionsParams.m_chargeSec[7] == 1, "Charge for secondary species 7 is not 1" );
54+
static_assert( chemicalReactionsParams.m_chargeSec[8] == 0, "Charge for secondary species 8 is not 0" );
55+
static_assert( chemicalReactionsParams.m_chargeSec[9] == 0, "Charge for secondary species 9 is not 0" );
56+
static_assert( chemicalReactionsParams.m_chargeSec[10] == -1, "Charge for secondary species 10 is not -1" );
57+
static_assert( nearEqual( chemicalReactionsParams.m_DebyeHuckelA, 0.5465 ), "Debye Huckel A is not 0.5465" );
58+
static_assert( nearEqual( chemicalReactionsParams.m_DebyeHuckelB, 0.3346 ), "Debye Huckel B is not 0.3346" );
59+
static_assert( nearEqual( chemicalReactionsParams.m_WATEQBDot, 0.0438 ), "WATEQBDot is not 0.0438" );
60+
static_assert( nearEqual( chemicalReactionsParams.m_eqStoichMatrix[0][0], -1.0 ), "Stoichiometry matrix for species 0 and 0 is not -1" );
61+
static_assert( nearEqual( chemicalReactionsParams.m_eqStoichMatrix[0][1], 0.0 ), "Stoichiometry matrix for species 0 and 1 is not 0" );
62+
static_assert( nearEqual( chemicalReactionsParams.m_eqStoichMatrix[0][2], 0.0 ), "Stoichiometry matrix for species 0 and 2 is not 0" );
63+
static_assert( nearEqual( chemicalReactionsParams.m_eqStoichMatrix[0][3], 0.0 ), "Stoichiometry matrix for species 0 and 3 is not 0" );
64+
static_assert( nearEqual( chemicalReactionsParams.m_eqStoichMatrix[0][4], 0.0 ), "Stoichiometry matrix for species 0 and 4 is not 0" );
65+
static_assert( nearEqual( chemicalReactionsParams.m_eqStoichMatrix[0][5], 0.0 ), "Stoichiometry matrix for species 0 and 5 is not 0" );
66+
static_assert( nearEqual( chemicalReactionsParams.m_eqStoichMatrix[0][6], 0.0 ), "Stoichiometry matrix for species 0 and 6 is not 0" );
67+
static_assert( nearEqual( chemicalReactionsParams.m_eqStoichMatrix[1][0], 1.0 ), "Stoichiometry matrix for species 1 and 0 is not 1" );
68+
static_assert( nearEqual( chemicalReactionsParams.m_eqStoichMatrix[1][1], 1.0 ), "Stoichiometry matrix for species 1 and 1 is not 1" );
69+
static_assert( nearEqual( chemicalReactionsParams.m_eqStoichMatrix[1][2], 0.0 ), "Stoichiometry matrix for species 1 and 2 is not 0" );
70+
static_assert( nearEqual( chemicalReactionsParams.m_eqStoichMatrix[1][3], 0.0 ), "Stoichiometry matrix for species 1 and 3 is not 0" );
71+
static_assert( nearEqual( chemicalReactionsParams.m_eqStoichMatrix[1][4], 0.0 ), "Stoichiometry matrix for species 1 and 4 is not 0" );
72+
static_assert( nearEqual( chemicalReactionsParams.m_eqStoichMatrix[1][5], 0.0 ), "Stoichiometry matrix for species 1 and 5 is not 0" );
73+
static_assert( nearEqual( chemicalReactionsParams.m_eqStoichMatrix[1][6], 0.0 ), "Stoichiometry matrix for species 1 and 6 is not 0" );
74+
static_assert( nearEqual( chemicalReactionsParams.m_eqLog10EqConst[0], 13.99 ), "Log10 equilibrium constant for species 0 is not 13.99" );
75+
static_assert( nearEqual( chemicalReactionsParams.m_eqLog10EqConst[1], -6.36 ) , "Log10 equilibrium constant for species 1 is not -6.36" );
76+
static_assert( nearEqual( chemicalReactionsParams.m_eqLog10EqConst[2], 10.33 ), "Log10 equilibrium constant for species 2 is not 10.33" );
77+
static_assert( nearEqual( chemicalReactionsParams.m_eqLog10EqConst[3], -3.77 ) , "Log10 equilibrium constant for species 3 is not -3.77" );
78+
static_assert( nearEqual( chemicalReactionsParams.m_eqLog10EqConst[4], -1.09 ) , "Log10 equilibrium constant for species 4 is not -1.09" );
79+
static_assert( nearEqual( chemicalReactionsParams.m_eqLog10EqConst[5], 7.07 ) , "Log10 equilibrium constant for species 5 is not 7.07" );
80+
static_assert( nearEqual( chemicalReactionsParams.m_eqLog10EqConst[6], - 2.16 ) , "Log10 equilibrium constant for species 6 is not -2.16" );
81+
static_assert( nearEqual( chemicalReactionsParams.m_eqLog10EqConst[7], + 0.67 ) , "Log10 equilibrium constant for species 7 is not 0.67" );
82+
static_assert( nearEqual( chemicalReactionsParams.m_eqLog10EqConst[8], + 0.60 ) , "Log10 equilibrium constant for species 8 is not 0.60" );
83+
static_assert( nearEqual( chemicalReactionsParams.m_eqLog10EqConst[9], - 2.43 ) , "Log10 equilibrium constant for species 9 is not -2.43" );
84+
static_assert( nearEqual( chemicalReactionsParams.m_eqLog10EqConst[10], - 0.82 ), "Log10 equilibrium constant for species 10 is not -0.82" );
85+
static_assert( nearEqual( chemicalReactionsParams.m_kineticStoichMatrix[0][0],-2.0 ), "Stoichiometry matrix for kinetic reaction 0 and species 0 is not -2" );
86+
static_assert( nearEqual( chemicalReactionsParams.m_kineticStoichMatrix[0][1], 0.0 ), "Stoichiometry matrix for kinetic reaction 0 and species 1 is not 0" );
87+
static_assert( nearEqual( chemicalReactionsParams.m_kineticStoichMatrix[0][2], 1.0 ), "Stoichiometry matrix for kinetic reaction 0 and species 2 is not 1" );
88+
static_assert( nearEqual( chemicalReactionsParams.m_kineticStoichMatrix[0][3], 0.0 ), "Stoichiometry matrix for kinetic reaction 0 and species 3 is not 0" );
89+
static_assert( nearEqual( chemicalReactionsParams.m_kineticStoichMatrix[0][4], 0.0 ), "Stoichiometry matrix for kinetic reaction 0 and species 4 is not 0" );
90+
static_assert( nearEqual( chemicalReactionsParams.m_kineticStoichMatrix[0][5], 0.0 ), "Stoichiometry matrix for kinetic reaction 0 and species 5 is not 0" );
91+
static_assert( nearEqual( chemicalReactionsParams.m_kineticStoichMatrix[0][6], 0.0 ), "Stoichiometry matrix for kinetic reaction 0 and species 6 is not 0" );
92+
static_assert( nearEqual( chemicalReactionsParams.m_kineticStoichMatrix[1][0],-1.0 ), "Stoichiometry matrix for kinetic reaction 1 and species 0 is not -1" );
93+
static_assert( nearEqual( chemicalReactionsParams.m_kineticStoichMatrix[1][1], 1.0 ), "Stoichiometry matrix for kinetic reaction 1 and species 1 is not 1" );
94+
static_assert( nearEqual( chemicalReactionsParams.m_kineticStoichMatrix[1][2], 1.0 ), "Stoichiometry matrix for kinetic reaction 1 and species 2 is not 1" );
95+
static_assert( nearEqual( chemicalReactionsParams.m_kineticStoichMatrix[1][3], 0.0 ), "Stoichiometry matrix for kinetic reaction 1 and species 3 is not 0" );
96+
static_assert( nearEqual( chemicalReactionsParams.m_kineticStoichMatrix[1][4], 0.0 ), "Stoichiometry matrix for kinetic reaction 1 and species 4 is not 0" );
97+
static_assert( nearEqual( chemicalReactionsParams.m_kineticStoichMatrix[1][5], 0.0 ), "Stoichiometry matrix for kinetic reaction 1 and species 5 is not 0" );
98+
static_assert( nearEqual( chemicalReactionsParams.m_kineticStoichMatrix[1][6], 0.0 ), "Stoichiometry matrix for kinetic reaction 1 and species 6 is not 0" );
99+
static_assert( nearEqual( chemicalReactionsParams.m_kineticlog10EqConst[0], 20.19 ), "Log10 equilibrium constant for kinetic reaction 0 is not 20.19" );
100+
static_assert( nearEqual( chemicalReactionsParams.m_kineticlog10EqConst[1], 1.32 ), "Log10 equilibrium constant for kinetic reaction 1 is not 1.32" );
101+
static_assert( nearEqual( chemicalReactionsParams.m_kineticReactionRateConstant[0], 9.95e-1 ), "Reaction rate constant for kinetic reaction 0 is not 9.95e-1" );
102+
static_assert( nearEqual( chemicalReactionsParams.m_kineticReactionRateConstant[1], 9.95e-3 ), "Reaction rate constant for kinetic reaction 1 is not 9.95e-3" );
103+
static_assert( nearEqual( chemicalReactionsParams.m_kineticSpecificSurfaceArea, 1.0 ), "Specific surface area is not 1.0" );
104+
105+
106+
}
107+
108+
TEST( testReactionsBase, testEquilibriumParamsExtraction )
109+
{
110+
constexpr auto eqParams = chemicalReactionsParams.equilibriumReactions();
111+
static_assert( eqParams.numPrimarySpecies == 7, "Number of primary species is not 7" );
112+
static_assert( eqParams.numSecondarySpecies == 11, "Number of secondary species is not 11" );
113+
static_assert( nearEqual( eqParams.m_ionSizePrimary[0], 9.00 ), "Ion size for primary species 0 is not 9.00" );
114+
static_assert( nearEqual( eqParams.m_ionSizePrimary[1], 4.00 ), "Ion size for primary species 1 is not 4.00" );
115+
static_assert( nearEqual( eqParams.m_ionSizePrimary[2], 6.00 ), "Ion size for primary species 2 is not 6.00" );
116+
static_assert( nearEqual( eqParams.m_ionSizePrimary[3], 4.00 ), "Ion size for primary species 3 is not 4.00" );
117+
static_assert( nearEqual( eqParams.m_ionSizePrimary[4], 3.00 ), "Ion size for primary species 4 is not 3.00" );
118+
static_assert( nearEqual( eqParams.m_ionSizePrimary[5], 8.00 ), "Ion size for primary species 5 is not 8.00" );
119+
static_assert( nearEqual( eqParams.m_ionSizePrimary[6], 4.00 ), "Ion size for primary species 6 is not 4.00" );
120+
static_assert( nearEqual( eqParams.m_ionSizeSec[0] , 3.50 ), "Ion size for secondary species 0 is not 3.50" );
121+
static_assert( nearEqual( eqParams.m_ionSizeSec[1] , 3.00 ), "Ion size for secondary species 1 is not 3.00" );
122+
static_assert( nearEqual( eqParams.m_ionSizeSec[2] , 4.50 ), "Ion size for secondary species 2 is not 4.50" );
123+
static_assert( nearEqual( eqParams.m_ionSizeSec[3] , 3.00 ), "Ion size for secondary species 3 is not 3.00" );
124+
static_assert( nearEqual( eqParams.m_ionSizeSec[
125+
}
126+
127+
128+
int main( int argc, char * * argv )
129+
{
130+
::testing::InitGoogleTest( &argc, argv );
131+
int const result = RUN_ALL_TESTS();
132+
return result;
133+
}

0 commit comments

Comments
 (0)