Skip to content

Commit 46ffd24

Browse files
Merge pull request #213 from YvanGuifo/AddProductOfNDArray
Addition of matrix product
2 parents 01ff4e3 + a1cc5e0 commit 46ffd24

File tree

3 files changed

+78
-4
lines changed

3 files changed

+78
-4
lines changed

src/Math-Matrix/PMNDArray.class.st

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,10 +39,10 @@ ifTrue:[
3939
]
4040

4141
{ #category : #comparing }
42-
PMNDArray >> = anArray [
42+
PMNDArray >> = aPMNDArray [
4343

44-
^ array = anArray array & (first = anArray first)
45-
& (strides = anArray strides) & (shape = anArray shape)
44+
^ array = aPMNDArray asArray & (first = aPMNDArray first)
45+
& (strides = aPMNDArray strides) & (shape = aPMNDArray shape)
4646
]
4747

4848
{ #category : #'as yet unclassified' }
@@ -97,6 +97,18 @@ PMNDArray >> fromNestedArray: aFlatArray withShape: aShape [
9797
shape ifNotEmpty: [ self updateStrides]
9898
]
9999

100+
{ #category : #operation }
101+
PMNDArray >> hadamardProduct: aPMNDArray [
102+
^ self with: aPMNDArray collect: [:a :b| a*b]
103+
104+
105+
]
106+
107+
{ #category : #testing }
108+
PMNDArray >> hasSameShapeAs: aPMNDArray [
109+
^ self shape = aPMNDArray shape
110+
]
111+
100112
{ #category : #accessing }
101113
PMNDArray >> rank [
102114
^ shape size
@@ -156,3 +168,13 @@ PMNDArray >> viewWithShape: aNewShape [
156168
^ PMNDArray new array: self asArray withShape: aNewShape
157169

158170
]
171+
172+
{ #category : #operation }
173+
PMNDArray >> with: aPMNDArray collect: aBlock [
174+
(self hasSameShapeAs: aPMNDArray)
175+
ifFalse:[ShapeMismatch signal ].
176+
^ self class new array: ((self asArray) with: (aPMNDArray asArray) collect: aBlock)
177+
withShape: self shape.
178+
179+
180+
]

src/Math-Matrix/PMNDArrayTest.class.st

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,58 @@ PMNDArrayTest >> testFromNestedArray [
100100

101101
]
102102

103+
{ #category : #tests }
104+
PMNDArrayTest >> testHadamardProduct [
105+
|t1 t2 expectedHadamardProduct|
106+
107+
t1 := PMNDArray fromNestedArray: #(
108+
#( 1 2 3 4 ) #( 5 6 7 8 )
109+
).
110+
t2 := PMNDArray fromNestedArray: #(
111+
#( 4 3 2 9 ) #( 9 7 6 5 )
112+
).
113+
expectedHadamardProduct := PMNDArray fromNestedArray:#(
114+
#( 4 6 6 36 ) #( 45 42 42 40 )
115+
).
116+
117+
self assert: (t1 hadamardProduct: t2) equals: expectedHadamardProduct.
118+
119+
t1 := PMNDArray fromNestedArray:
120+
#(
121+
#( #( 1 2 ) #( 3 4 ))
122+
#( #( 5 6 ) #( 7 8 ))
123+
#( #( 9 10) #( 11 12))
124+
).
125+
126+
t2 := PMNDArray fromNestedArray:
127+
#(
128+
#( #( 1 3 ) #( 2 1))
129+
#( #( 0 6 ) #( 3 8))
130+
#( #( 3 5) #( 1 10))
131+
).
132+
133+
expectedHadamardProduct := PMNDArray fromNestedArray:#(
134+
#( #( 1 6 ) #( 6 4 ))
135+
#( #( 0 36 ) #( 21 64 ))
136+
#( #( 27 50 ) #( 11 120 ))).
137+
138+
139+
self assert: (t1 hadamardProduct: t2) equals: expectedHadamardProduct
140+
141+
142+
143+
144+
145+
]
146+
147+
{ #category : #tests }
148+
PMNDArrayTest >> testHadamardProductWithDifferentShapesFails [
149+
|t1 t2|
150+
t1 := PMNDArray fromNestedArray: #( #( 1 2) #(3 4) #(5 6) #(7 8)).
151+
t2 := PMNDArray fromNestedArray: #( #( 4 3 2 9 ) #( 9 7 6 5 ) ).
152+
self should: [t1 hadamardProduct: t2] raise: ShapeMismatch
153+
]
154+
103155
{ #category : #tests }
104156
PMNDArrayTest >> testRank [
105157

src/Math-Matrix/ShapeMismatch.class.st

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,5 +15,5 @@ ShapeMismatch >> messageText [
1515
ShapeMismatch >> standardMessageText [
1616
"Generate a standard textual description"
1717

18-
^ 'Tensor shapes do not match'
18+
^ 'NDArray shapes do not match'
1919
]

0 commit comments

Comments
 (0)