Skip to content

Commit 52aa628

Browse files
author
Qiang Kou
committed
close #365
1 parent eda71fd commit 52aa628

File tree

4 files changed

+66
-0
lines changed

4 files changed

+66
-0
lines changed

ChangeLog

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,9 @@
1+
2015-12-04 Qiang Kou <[email protected]>
2+
3+
* inst/include/Rcpp/vector/Matrix.h: Add math operators between matrix and scalar
4+
* inst/unitTests/runit.Matrix.R: Unit tests
5+
* inst/unitTests/cpp/Matrix.cpp: Unit tests
6+
17
2015-11-27 JJ Allaire <[email protected]>
28

39
* src/attributes.cpp: Avoid invalid function names when

inst/include/Rcpp/vector/Matrix.h

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,39 @@ inline std::ostream &operator<<(std::ostream & s, const Matrix<REALSXP, StorageP
247247
return s;
248248
}
249249

250+
#define RCPP_GENERATE_MATRIX_SCALAR_OPERATOR(__OPERATOR__) \
251+
template <int RTYPE, template <class> class StoragePolicy > \
252+
inline Matrix<RTYPE, StoragePolicy> operator __OPERATOR__ (const Matrix<RTYPE, StoragePolicy> &lhs, \
253+
const typename Matrix<RTYPE, StoragePolicy>::stored_type &rhs) { \
254+
Vector<RTYPE, StoragePolicy> v = static_cast<const Vector<RTYPE, StoragePolicy> &>(lhs) __OPERATOR__ rhs; \
255+
v.attr("dim") = Vector<INTSXP>::create(lhs.nrow(), lhs.ncol()); \
256+
return as< Matrix<RTYPE, StoragePolicy> >(v); \
257+
}
258+
259+
RCPP_GENERATE_MATRIX_SCALAR_OPERATOR(+)
260+
RCPP_GENERATE_MATRIX_SCALAR_OPERATOR(-)
261+
RCPP_GENERATE_MATRIX_SCALAR_OPERATOR(*)
262+
RCPP_GENERATE_MATRIX_SCALAR_OPERATOR(/)
263+
264+
#undef RCPP_GENERATE_MATRIX_SCALAR_OPERATOR
265+
266+
#define RCPP_GENERATE_SCALAR_MATRIX_OPERATOR(__OPERATOR__) \
267+
template <int RTYPE, template <class> class StoragePolicy > \
268+
inline Matrix<RTYPE, StoragePolicy> operator __OPERATOR__ (const typename Matrix<RTYPE, StoragePolicy>::stored_type &lhs, \
269+
const Matrix<RTYPE, StoragePolicy> &rhs) { \
270+
Vector<RTYPE, StoragePolicy> v = static_cast<const Vector<RTYPE, StoragePolicy> &>(rhs); \
271+
v = lhs __OPERATOR__ v; \
272+
v.attr("dim") = Vector<INTSXP>::create(rhs.nrow(), rhs.ncol()); \
273+
return as< Matrix<RTYPE, StoragePolicy> >(v); \
274+
}
275+
276+
RCPP_GENERATE_SCALAR_MATRIX_OPERATOR(+)
277+
RCPP_GENERATE_SCALAR_MATRIX_OPERATOR(-)
278+
RCPP_GENERATE_SCALAR_MATRIX_OPERATOR(*)
279+
RCPP_GENERATE_SCALAR_MATRIX_OPERATOR(/)
280+
281+
#undef RCPP_GENERATE_SCALAR_MATRIX_OPERATOR
282+
250283
template<template <class> class StoragePolicy >
251284
inline std::ostream &operator<<(std::ostream & s, const Matrix<INTSXP, StoragePolicy> & rhs) {
252285
typedef Matrix<INTSXP, StoragePolicy> MATRIX;

inst/unitTests/cpp/Matrix.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -311,3 +311,23 @@ NumericMatrix transposeNumeric(const NumericMatrix & x) {
311311
CharacterMatrix transposeCharacter(const CharacterMatrix & x) {
312312
return transpose(x);
313313
}
314+
315+
// [[Rcpp::export]]
316+
NumericMatrix matrix_scalar_plus(const NumericMatrix & x, int y) {
317+
return x + y;
318+
}
319+
320+
// [[Rcpp::export]]
321+
NumericMatrix matrix_scalar_plus2(const NumericMatrix & x, int y) {
322+
return y + x;
323+
}
324+
325+
// [[Rcpp::export]]
326+
NumericMatrix matrix_scalar_divide(const NumericMatrix & x, int y) {
327+
return x / y;
328+
}
329+
330+
// [[Rcpp::export]]
331+
NumericMatrix matrix_scalar_divide2(const NumericMatrix & x, int y) {
332+
return y / x;
333+
}

inst/unitTests/runit.Matrix.R

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -242,4 +242,11 @@ if (.runThisTest) {
242242
checkEquals(transposeCharacter(M), t(M), msg="character transpose with row and colnames")
243243
}
244244

245+
test.Matrix.Scalar.op <- function() {
246+
M <- matrix(c(1:12), 3, 4)
247+
checkEquals(matrix_scalar_plus(M, 2), M + 2, msg="matrix + scalar")
248+
checkEquals(matrix_scalar_plus2(M, 2), 2 + M, msg="scalar + matrix")
249+
checkEquals(matrix_scalar_divide(M, 2), M / 2, msg="matrix / scalar")
250+
checkEquals(matrix_scalar_divide2(M, 2), 2 / M, msg="scalar / matrix")
251+
}
245252
}

0 commit comments

Comments
 (0)