Skip to content

Commit ea2fd57

Browse files
authored
Merge pull request #166 from amcamd/refactor_tests
refactor set/get vector/matrix for host/dev ptr mode
2 parents 49e04c7 + 130ef8b commit ea2fd57

File tree

8 files changed

+269
-231
lines changed

8 files changed

+269
-231
lines changed

clients/benchmarks/client.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222
#include "testing_trtri.hpp"
2323
#include "testing_trtri_batched.hpp"
2424
#include "testing_geam.hpp"
25+
#include "testing_set_get_vector.hpp"
26+
#include "testing_set_get_matrix.hpp"
2527
#if BUILD_WITH_TENSILE
2628
#include "testing_gemm.hpp"
2729
#include "testing_gemm_strided_batched.hpp"
@@ -171,6 +173,18 @@ int main(int argc, char *argv[])
171173
else if (precision == 'd')
172174
testing_geam<double>( argus );
173175
}
176+
else if (function == "set_get_vector"){
177+
if (precision == 's')
178+
testing_set_get_vector<float>( argus );
179+
else if (precision == 'd')
180+
testing_set_get_vector<double>( argus );
181+
}
182+
else if (function == "set_get_matrix"){
183+
if (precision == 's')
184+
testing_set_get_matrix<float>( argus );
185+
else if (precision == 'd')
186+
testing_set_get_matrix<double>( argus );
187+
}
174188
#if BUILD_WITH_TENSILE
175189
else if (function == "gemm"){
176190

clients/common/arg_check.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,20 @@ void set_get_matrix_arg_check(rocblas_status status, rocblas_int rows, rocblas_i
3535
#endif
3636
}
3737

38+
void set_get_vector_arg_check(rocblas_status status, rocblas_int M, rocblas_int incx,
39+
rocblas_int incy, rocblas_int incd)
40+
{
41+
#ifdef GOOGLE_TEST
42+
ASSERT_EQ(status, rocblas_status_invalid_size);
43+
#else
44+
if (status != rocblas_status_invalid_size)
45+
{
46+
std::cerr << "ERROR in arguments M, incx, incy, incd: ";
47+
std::cerr << M << ',' << incx << ',' << incy << ',' << incd << std::endl;
48+
}
49+
#endif
50+
}
51+
3852
void gemv_ger_arg_check(rocblas_status status, rocblas_int M, rocblas_int N, rocblas_int lda,
3953
rocblas_int incx, rocblas_int incy)
4054
{

clients/gtest/set_get_matrix_gtest.cpp

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -159,22 +159,44 @@ TEST_P(set_matrix_get_matrix_gtest, float)
159159
rocblas_status status = testing_set_get_matrix<float>( arg );
160160

161161
// if not success, then the input argument is problematic, so detect the error message
162-
if(status != rocblas_status_success){
163-
if( arg.rows < 0 ){
162+
if (status != rocblas_status_success)
163+
{
164+
if (arg.rows < 0)
165+
{
164166
EXPECT_EQ(rocblas_status_invalid_size, status);
165167
}
166-
else if(arg.cols <= 0){
168+
else if (arg.cols <= 0)
169+
{
167170
EXPECT_EQ(rocblas_status_invalid_size, status);
168171
}
169-
else if(arg.lda <= 0){
172+
else if (arg.lda <= 0)
173+
{
170174
EXPECT_EQ(rocblas_status_invalid_size, status);
171175
}
172-
else if(arg.ldb <= 0){
176+
else if (arg.ldb <= 0)
177+
{
173178
EXPECT_EQ(rocblas_status_invalid_size, status);
174179
}
175-
else if(arg.ldc <= 0){
180+
else if (arg.ldc <= 0)
181+
{
176182
EXPECT_EQ(rocblas_status_invalid_size, status);
177183
}
184+
else if (arg.lda < arg.rows)
185+
{
186+
EXPECT_EQ(rocblas_status_invalid_size, status);
187+
}
188+
else if (arg.ldb < arg.rows)
189+
{
190+
EXPECT_EQ(rocblas_status_invalid_size, status);
191+
}
192+
else if (arg.ldc < arg.rows)
193+
{
194+
EXPECT_EQ(rocblas_status_invalid_size, status);
195+
}
196+
else
197+
{
198+
EXPECT_EQ(rocblas_status_success, status);
199+
}
178200
}
179201
}
180202

clients/gtest/set_get_vector_gtest.cpp

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -43,10 +43,10 @@ Representative sampling is sufficient, endless brute-force sampling is not neces
4343
const
4444
int M_range[] = { 600, 6000000 };
4545

46-
//vector of vector, each triple is a {incx, incy, incd};
46+
//vector of vector, each triple is a {incx, incy, incb};
4747
//add/delete this list in pairs, like {1, 1, 1}
4848
const
49-
vector<vector<int>> incx_incy_incd_range = {
49+
vector<vector<int>> incx_incy_incb_range = {
5050
{1, 1, 1},
5151
{1, 1, 2},
5252
{1, 1, 3},
@@ -87,17 +87,17 @@ Arguments setup_set_get_vector_arguments(set_get_vector_tuple tup)
8787
{
8888

8989
int M = std::get<0>(tup);
90-
vector<int> incx_incy_incd = std::get<1>(tup);
90+
vector<int> incx_incy_incb = std::get<1>(tup);
9191

9292
Arguments arg;
9393

9494
// see the comments about vector_size_range above
9595
arg.M = M;
9696

9797
// see the comments about matrix_size_range above
98-
arg.incx = incx_incy_incd[0];
99-
arg.incy = incx_incy_incd[1];
100-
arg.incd = incx_incy_incd[2];
98+
arg.incx = incx_incy_incb[0];
99+
arg.incy = incx_incy_incb[1];
100+
arg.incb = incx_incy_incb[2];
101101

102102
return arg;
103103
}
@@ -126,19 +126,28 @@ TEST_P(set_vector_get_vector_gtest, float)
126126
rocblas_status status = testing_set_get_vector<float>( arg );
127127

128128
// if not success, then the input argument is problematic, so detect the error message
129-
if(status != rocblas_status_success){
130-
if( arg.M < 0 ){
129+
if(status != rocblas_status_success)
130+
{
131+
if( arg.M < 0 )
132+
{
131133
EXPECT_EQ(rocblas_status_invalid_size, status);
132134
}
133-
else if(arg.incx <= 0){
135+
else if(arg.incx <= 0)
136+
{
134137
EXPECT_EQ(rocblas_status_invalid_size, status);
135138
}
136-
else if(arg.incy <= 0){
139+
else if(arg.incy <= 0)
140+
{
137141
EXPECT_EQ(rocblas_status_invalid_size, status);
138142
}
139-
else if(arg.incd <= 0){
143+
else if(arg.incb <= 0)
144+
{
140145
EXPECT_EQ(rocblas_status_invalid_size, status);
141146
}
147+
else
148+
{
149+
EXPECT_EQ(rocblas_status_success, status);
150+
}
142151
}
143152
}
144153

@@ -150,6 +159,6 @@ TEST_P(set_vector_get_vector_gtest, float)
150159
INSTANTIATE_TEST_CASE_P(rocblas_auxiliary_small,
151160
set_vector_get_vector_gtest,
152161
Combine(
153-
ValuesIn(M_range), ValuesIn(incx_incy_incd_range)
162+
ValuesIn(M_range), ValuesIn(incx_incy_incb_range)
154163
)
155164
);

clients/include/arg_check.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@
3737

3838
void set_get_matrix_arg_check(rocblas_status status, rocblas_int rows, rocblas_int cols, rocblas_int lda, rocblas_int ldb, rocblas_int ldc);
3939

40+
void set_get_vector_arg_check(rocblas_status status, rocblas_int M, rocblas_int incx, rocblas_int incy, rocblas_int incd);
41+
4042
void gemv_ger_arg_check(rocblas_status status, rocblas_int M, rocblas_int N, rocblas_int lda, rocblas_int incx, rocblas_int incy);
4143

4244
void gemm_arg_check(rocblas_status status, rocblas_int M, rocblas_int N, rocblas_int K, rocblas_int lda, rocblas_int ldb, rocblas_int ldc);

0 commit comments

Comments
 (0)