Skip to content

Commit 024af6b

Browse files
authored
check for invalid shapes in strassen multiplication (#208)
* check for invalid shapes in strassen multiplication Strassen multiplication is only possible when the matrix dimensions are a power of two. * reorder variables in assert condition * fix indentation
1 parent 67ad541 commit 024af6b

File tree

2 files changed

+28
-0
lines changed

2 files changed

+28
-0
lines changed

src/Math-Matrix/PMMatrix.class.st

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -950,8 +950,13 @@ PMMatrix >> squared [
950950
PMMatrix >> strassenProductWithMatrix: aMatrix [
951951
"Private"
952952
| matrixSplit selfSplit p1 p2 p3 p4 p5 p6 p7 |
953+
953954
( self numberOfRows > 2 and: [ self numberOfColumns > 2])
954955
ifFalse:[ ^self class rows: ( aMatrix rowsCollect: [ :row | self columnsCollect: [ :col | row * col]])].
956+
957+
self assert: [ self numberOfRows isPowerOfTwo and: self numberOfColumns isPowerOfTwo ] description: 'Matrix size should be a power of two'.
958+
self assert: [ aMatrix numberOfRows isPowerOfTwo and: aMatrix numberOfColumns isPowerOfTwo ] description: 'Matrix size should be a power of two'.
959+
955960
selfSplit := self split.
956961
matrixSplit := aMatrix split.
957962
p1 := ( ( selfSplit at: 2) - ( selfSplit at: 4)) strassenProductWithMatrix: ( matrixSplit at: 1).

src/Math-Tests-Matrix/PMMatrixTest.class.st

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -737,6 +737,29 @@ PMMatrixTest >> testSkalarMultiplication [
737737
self assert: r class equals: PMSymmetricMatrix
738738
]
739739

740+
{ #category : #tests }
741+
PMMatrixTest >> testStrassenProductWithMatrix [
742+
743+
| aPMMatrix expected strassenProduct |
744+
745+
aPMMatrix := PMMatrix rows: #( (1 2 3 4) (1 0 1 2) (2 3 1 4) (4 3 2 2)).
746+
expected := PMMatrix rows: #( (30 12 27 24) (12 6 11 10) (27 11 30 27) (24 10 27 33)).
747+
strassenProduct := aPMMatrix transpose strassenProductWithMatrix: aPMMatrix.
748+
self assert: strassenProduct equals: expected.
749+
]
750+
751+
{ #category : #tests }
752+
PMMatrixTest >> testStrassenProductWithMatrixWithInvalidShapes [
753+
754+
| aPMMatrix |
755+
756+
"all the dimension of the matrices should be a power of 2 for strassen multiplication"
757+
aPMMatrix := PMMatrix rows: #( (1 2 3) (1 0 1) (2 3 1) (2 0 1)).
758+
759+
self should: [ aPMMatrix strassenProductWithMatrix: aPMMatrix transpose ] raise: AssertionFailure.
760+
761+
]
762+
740763
{ #category : #tests }
741764
PMMatrixTest >> testSymmetric [
742765
|a m|

0 commit comments

Comments
 (0)