Skip to content

Commit 9bf32cd

Browse files
committed
Overloaded sum method now accepts the out type argument, similar to Matlab's sum. Also performs summation similarly to Matlab's sum. Thanks to @marcsous.
1 parent 902a2e5 commit 9bf32cd

File tree

1 file changed

+58
-23
lines changed

1 file changed

+58
-23
lines changed

MappedTensor.m

Lines changed: 58 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -826,34 +826,69 @@ function disp(mtVar)
826826
'*** MappedTensor: Concatenation is not supported for MappedTensor objects.');
827827
end
828828

829-
%% sum - METHOD Overloaded sum function for usage "sum(mtVar <, dim>)"
829+
%% sum - METHOD Overloaded sum function for usage "sum(mtVar <, dim, outtype>)"
830830
function [tFinalSum] = sum(mtVar, varargin)
831831
% - Get tensor size
832832
vnTensorSize = size(mtVar);
833833

834-
if (exist('varargin', 'var') && ~isempty(varargin))
835-
% - Check varargin for string parameters and discard
836-
vbIsString = cellfun(@ischar, varargin);
837-
varargin = varargin(~vbIsString);
838-
839-
% - Too many arguments?
840-
if (numel(varargin) > 1)
841-
error('MappedTensor:sum:InvalidArguments', ...
842-
'*** MappedTensor/sum: Too many arguments were supplied.');
834+
% - By default, sum along first non-singleton dimension
835+
DEF_nDim = find(vnTensorSize > 1, 1, 'first');
836+
837+
% - By default, accumulate in a double tensor
838+
DEF_strReturnClass = 'double';
839+
840+
% - Check arguments and apply defaults
841+
if (nargin > 3)
842+
error('MappedTensor:sum:InvalidArguments', ...
843+
'*** MappedTensor/sum: Too many arguments were supplied.');
844+
845+
elseif (nargin == 3)
846+
847+
elseif (nargin == 2)
848+
if (ischar(varargin{1}))
849+
varargin{2} = varargin{1};
850+
varargin{1} = DEF_nDim;
851+
852+
else
853+
varargin{2} = DEF_strReturnClass;
843854
end
844-
845-
% - Was a dimension specified?
846-
if (~isnumeric(varargin{1}) || numel(varargin{1}) > 1)
847-
error('MappedTensor:sum:InvalidArguments', ...
848-
'*** MappedTensor/sum: ''dim'' must be supplied as a scalar number.');
855+
856+
elseif (nargin == 1)
857+
varargin{1} = DEF_nDim;
858+
varargin{2} = DEF_strReturnClass;
859+
end
860+
861+
% - Was a valid dimension specified?
862+
try
863+
validateattributes(varargin{1}, {'numeric'}, {'positive', 'integer', 'scalar'});
864+
catch
865+
error('MappedTensor:sum:InvalidArguments', ...
866+
'*** MappedTensor/sum: ''dim'' must be supplied as a positive scalar number.');
867+
end
868+
nDim = varargin{1};
869+
870+
% - Was a valid output argument type specified?
871+
try
872+
strReturnClass = validatestring(lower(varargin{2}), {'native', 'double', 'default'});
873+
catch
874+
error('MappedTensor:sum:InvalidArguments', ...
875+
'*** MappedTensor/sum: ''outtype'' must be one of {''double'', ''native'', ''default''}.');
876+
end
877+
878+
% - Get the class for the summation matrix
879+
if (strcmp(strReturnClass, 'native'))
880+
% - Logicals are always summed in a double tensor
881+
if (islogical(mtVar))
882+
strOutputClass = 'double';
883+
else
884+
strOutputClass = mtVar.strClass;
849885
end
886+
887+
elseif (strcmp(strReturnClass, 'default'))
888+
strOutputClass = DEF_strReturnClass;
850889

851-
% - Record dimension to sum along
852-
nDim = varargin{1};
853-
854-
else
855-
% - By default, sum along first non-singleton dimension
856-
nDim = find(vnTensorSize > 1, 1, 'first');
890+
else %if (strcmp(strReturnClass, 'double'))
891+
strOutputClass = strReturnClass;
857892
end
858893

859894
% -- Sum in chunks to avoid allocating full tensor
@@ -878,7 +913,7 @@ function disp(mtVar)
878913
end
879914

880915
% -- Perform sum by taking dimensions in turn
881-
tFinalSum = zeros(vnSumSize);
916+
tFinalSum = zeros(vnSumSize, strOutputClass);
882917

883918
% - Construct referencing structures
884919
sSourceRef = substruct('()', ':');
@@ -898,7 +933,7 @@ function disp(mtVar)
898933
% - Call subsasgn, subsref and sum to process data
899934
sSourceRef.subs = cellTheseSourceIndices;
900935
sDestRef.subs = cellTheseDestIndices;
901-
tFinalSum = subsasgn(tFinalSum, sDestRef, subsref(tFinalSum, sDestRef) + sum(subsref(mtVar, sSourceRef), nDim));
936+
tFinalSum = subsasgn(tFinalSum, sDestRef, subsref(tFinalSum, sDestRef) + sum(subsref(mtVar, sSourceRef), nDim, strReturnClass));
902937

903938
% - Increment first non-max index
904939
nIncrementDim = find(vnSplitIndices <= vnNumDivisions, 1, 'first');

0 commit comments

Comments
 (0)