Skip to content

Commit d689214

Browse files
Merge pull request #210 from YvanGuifo/PMTensorClass
Addition of class and operations on tensors
2 parents 1ab9ebc + 0a77c47 commit d689214

File tree

3 files changed

+368
-0
lines changed

3 files changed

+368
-0
lines changed

src/Math-Matrix/PMNDArray.class.st

Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,158 @@
1+
Class {
2+
#name : #PMNDArray,
3+
#superclass : #Object,
4+
#instVars : [
5+
'array',
6+
'shape',
7+
'first',
8+
'strides'
9+
],
10+
#category : #'Math-Matrix'
11+
}
12+
13+
{ #category : #'instance creation' }
14+
PMNDArray class >> fromNestedArray: anArray [
15+
16+
^ self new array: anArray flattened withShape: (self shape:anArray)
17+
]
18+
19+
{ #category : #'instance creation' }
20+
PMNDArray class >> fromScalar: anInteger [
21+
22+
^ self new array: {anInteger} withShape: #( )
23+
]
24+
25+
{ #category : #'instance creation' }
26+
PMNDArray class >> newWith: anInteger [
27+
28+
^ self new array: {anInteger} withShape: #( )
29+
]
30+
31+
{ #category : #accessing }
32+
PMNDArray class >> shape: anArray [
33+
34+
anArray isArray ifFalse:[^#()]
35+
ifTrue:[
36+
^ {anArray size}, (self shape: anArray first)
37+
]
38+
39+
]
40+
41+
{ #category : #comparing }
42+
PMNDArray >> = anArray [
43+
44+
^ array = anArray array & (first = anArray first)
45+
& (strides = anArray strides) & (shape = anArray shape)
46+
]
47+
48+
{ #category : #'as yet unclassified' }
49+
PMNDArray >> array: aFlatArray withShape: aShape [
50+
51+
array := aFlatArray.
52+
shape := aShape copy.
53+
self updateFirst.
54+
shape ifNotEmpty: [ self updateStrides]
55+
]
56+
57+
{ #category : #private }
58+
PMNDArray >> asArray [
59+
^array
60+
]
61+
62+
{ #category : #public }
63+
PMNDArray >> at: coords [
64+
65+
| position |
66+
position := self flattenedIndexOf: coords.
67+
^ array at: position
68+
]
69+
70+
{ #category : #initialization }
71+
PMNDArray >> at: coords put: aValue [
72+
73+
array at: (self flattenedIndexOf: coords) put: aValue
74+
]
75+
76+
{ #category : #accessing }
77+
PMNDArray >> first [
78+
^first
79+
]
80+
81+
{ #category : #'primitives - file' }
82+
PMNDArray >> flattenedIndexOf: coords [
83+
84+
| position |
85+
position := 1.
86+
coords withIndexDo: [ :elt :i |
87+
position := (elt - 1) * (strides at: i) + position ].
88+
^ position
89+
]
90+
91+
{ #category : #'as yet unclassified' }
92+
PMNDArray >> fromNestedArray: aFlatArray withShape: aShape [
93+
94+
array := aFlatArray.
95+
shape := aShape copy.
96+
self updateFirst.
97+
shape ifNotEmpty: [ self updateStrides]
98+
]
99+
100+
{ #category : #accessing }
101+
PMNDArray >> rank [
102+
^ shape size
103+
]
104+
105+
{ #category : #'as yet unclassified' }
106+
PMNDArray >> reshape: aNewShape [
107+
108+
^ self viewWithShape: aNewShape.
109+
110+
]
111+
112+
{ #category : #accessing }
113+
PMNDArray >> shape [
114+
115+
^ shape
116+
]
117+
118+
{ #category : #accessing }
119+
PMNDArray >> size [
120+
121+
^ shape inject: 1 into: [ :each :product | each * product].
122+
123+
]
124+
125+
{ #category : #accessing }
126+
PMNDArray >> strides [
127+
^strides
128+
]
129+
130+
{ #category : #'as yet unclassified' }
131+
PMNDArray >> updateFirst [
132+
133+
first := Array new: shape size withAll: 1
134+
]
135+
136+
{ #category : #'as yet unclassified' }
137+
PMNDArray >> updateStrides [
138+
139+
strides := Array new: shape size.
140+
strides at: shape size put: 1.
141+
((shape size -1) to: 1 by: -1) do: [ :i |
142+
strides at: i put: ((strides at: i + 1) * (shape at: i+1))]
143+
]
144+
145+
{ #category : #'as yet unclassified' }
146+
PMNDArray >> view [
147+
148+
"Share only the array"
149+
150+
^ self viewWithShape: shape
151+
]
152+
153+
{ #category : #'as yet unclassified' }
154+
PMNDArray >> viewWithShape: aNewShape [
155+
156+
^ PMNDArray new array: self asArray withShape: aNewShape
157+
158+
]
Lines changed: 191 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,191 @@
1+
Class {
2+
#name : #PMNDArrayTest,
3+
#superclass : #TestCase,
4+
#category : #'Math-Matrix'
5+
}
6+
7+
{ #category : #tests }
8+
PMNDArrayTest >> testArray [
9+
10+
| t1 t2 |
11+
t1 := PMNDArray fromNestedArray: #( #( 1 2 3 4 ) #( 5 6 7 8 ) ).
12+
self assert: t1 asArray equals: #( 1 2 3 4 5 6 7 8 ).
13+
14+
t2 := PMNDArray fromNestedArray: #( #( #( 1 2 ) #( 3 4 ) ) #( #( 5 6 ) #( 7 8 ) )
15+
#( #( 9 10 ) #( 11 12 ) ) #( #( 13 14 ) #( 15 16 ) ) ).
16+
self
17+
assert: t2 asArray
18+
equals: #( 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 )
19+
]
20+
21+
{ #category : #tests }
22+
PMNDArrayTest >> testAt [
23+
24+
| t1 t2 |
25+
t1 := PMNDArray fromNestedArray: #( #( 1 2 3 4 ) #( 5 6 7 8 ) ).
26+
self assert: (t1 at: #( 2 2 )) equals: 6.
27+
28+
t2 := PMNDArray fromNestedArray: #( #( #( 1 2 ) #( 3 4 ) ) #( #( 5 6 ) #( 7 8 ) )
29+
#( #( 9 10 ) #( 11 12 ) ) #( #( 13 14 ) #( 15 16 ) ) ).
30+
self assert: (t2 at: #( 3 2 1 )) equals: 11.
31+
32+
self should:[t1 at: #( 4 4 )] raise:Error
33+
]
34+
35+
{ #category : #tests }
36+
PMNDArrayTest >> testAtPut [
37+
38+
| t1 t2 |
39+
t1 := PMNDArray fromNestedArray: #( #( 1 2 3 4 ) #( 5 6 7 8 ) ).
40+
t1 at: #( 2 2 ) put: 3.
41+
self assert: (t1 at: #( 2 2 ) ) equals: 3.
42+
43+
t2 := PMNDArray fromNestedArray: #( #( #( 1 2 ) #( 3 4 ) )
44+
#( #( 5 6 ) #( 7 8 ) )
45+
#( #( 9 10 ) #( 11 12 ) )
46+
#( #( 13 14 ) #( 15 16 ) ) ).
47+
t2 at: #( 2 2 1) put: 10.
48+
self assert: (t2 at: #( 2 2 1 )) equals: 10
49+
]
50+
51+
{ #category : #tests }
52+
PMNDArrayTest >> testCreateScalarNDArray [
53+
54+
| s |
55+
s := PMNDArray fromScalar: 2.
56+
self assert: (s at: #( )) equals: 2.
57+
self should: [ s at: #( 1 1 ) ] raise: Error.
58+
self assert: s rank equals: 0.
59+
s at: #( ) put: 1.
60+
self assert: (s at: #( )) equals: 1.
61+
self assert: s shape equals: #( ).
62+
self assert: s size equals: 1
63+
]
64+
65+
{ #category : #tests }
66+
PMNDArrayTest >> testFirst [
67+
68+
| a b |
69+
a := PMNDArray fromNestedArray: (1 to: 6) asArray.
70+
self assert: a first equals: #( 1).
71+
b := a reshape: #( 3 2 ).
72+
self assert: b first equals: #( 1 1 )
73+
]
74+
75+
{ #category : #tests }
76+
PMNDArrayTest >> testFlattenedIndexOf [
77+
78+
| t1 t2 |
79+
t1 := PMNDArray fromNestedArray: #( #( 1 2 3 4 ) #( 5 6 7 8) #( 9 10 11 12)).
80+
self assert: (t1 flattenedIndexOf: #( 3 2 )) equals: 10.
81+
82+
t2 := PMNDArray fromNestedArray: #( #( #( 1 2 ) #( 3 4 ) ) #( #( 5 6 ) #( 7 8 ) )
83+
#( #( 9 10 ) #( 11 12 ) ) #( #( 13 14 ) #( 15 16 ) ) ).
84+
self assert: (t2 flattenedIndexOf: #( 1 2 2 )) equals: 4
85+
]
86+
87+
{ #category : #tests }
88+
PMNDArrayTest >> testFromNestedArray [
89+
90+
| t1 t2 |
91+
t1 := PMNDArray fromNestedArray: #( #( 1 2 3 4 )
92+
#( 5 6 7 8 ) ).
93+
self assert: t1 class equals: PMNDArray.
94+
95+
t2 := PMNDArray fromNestedArray: #( #( #( 1 1 ) #( 2 2 ) )
96+
#( #( 3 3 ) #( 4 4 ) )
97+
#( #( 4 4 ) #( 4 4 ) )
98+
#( #( 4 4 ) #( 4 4 ) ) ).
99+
self assert: t2 class equals: PMNDArray.
100+
101+
]
102+
103+
{ #category : #tests }
104+
PMNDArrayTest >> testRank [
105+
106+
| t1 t2 |
107+
108+
t1 := PMNDArray fromNestedArray: #( #( 1 2 3 4 ) #( 5 6 7 8 ) ).
109+
self assert: t1 rank equals: 2.
110+
111+
t2 := PMNDArray fromNestedArray: #( #( #( 1 2 ) #( 3 4 ) ) #( #( 5 6 ) #( 7 8 ) )
112+
#( #( 9 10 ) #( 11 12 ) ) #( #( 13 14 ) #( 15 16 ) ) ).
113+
self assert: t2 rank equals: 3
114+
]
115+
116+
{ #category : #tests }
117+
PMNDArrayTest >> testReshape [
118+
119+
| t t1 |
120+
t := PMNDArray fromNestedArray: #( #( 0 1 ) #( 2 3 ) #( 4 5 ) ).
121+
t1 := t reshape: #( 2 3 ).
122+
123+
self assert: t shape equals: #( 3 2 ).
124+
self assert: t1 shape equals: #( 2 3 ).
125+
self assert: t1 asArray == t asArray equals: true
126+
]
127+
128+
{ #category : #tests }
129+
PMNDArrayTest >> testShape [
130+
131+
| t1 t2 |
132+
t1 := PMNDArray fromNestedArray: #( #( 1 2 3 4 )
133+
#( 5 6 7 8 ) ).
134+
self assert: t1 shape equals: #( 2 4 ).
135+
136+
t2 := PMNDArray fromNestedArray: #( #( #( 1 1 ) #( 2 2 ) )
137+
#( #( 3 3 ) #( 4 4 ) )
138+
#( #( 4 4 ) #( 4 4 ) )
139+
#( #( 4 4 ) #( 4 4 ) ) ).
140+
self assert: t2 shape equals: #( 4 2 2 )
141+
]
142+
143+
{ #category : #tests }
144+
PMNDArrayTest >> testSize [
145+
146+
| t1 t2 |
147+
t1 := PMNDArray fromNestedArray: #( #( 1 2 3 4 ) #( 5 6 7 8 ) ).
148+
self assert: t1 size equals: 8.
149+
150+
t2 := PMNDArray fromNestedArray: #( #( #( 1 2 ) #( 3 4 ) ) #( #( 5 6 ) #( 7 8 ) )
151+
#( #( 9 10 ) #( 11 12 ) ) #( #( 13 14 ) #( 15 16 ) ) ).
152+
self assert: t2 size equals: 16
153+
]
154+
155+
{ #category : #tests }
156+
PMNDArrayTest >> testStrides [
157+
158+
| a b |
159+
a := PMNDArray fromNestedArray: (1 to: 24) asArray.
160+
self assert: a strides equals: #( 1 ).
161+
b := a reshape: #( 4 6 ).
162+
self assert: b strides equals: #( 6 1 ).
163+
b := a reshape: #( 6 4 ).
164+
self assert: b strides equals: #( 4 1 ).
165+
self assert: (b flattenedIndexOf: #( 4 2 )) equals: 14.
166+
b := a reshape: #( 3 4 2 ).
167+
self assert: b strides equals: #( 8 2 1 ).
168+
self assert: (b flattenedIndexOf: #( 3 2 1)) equals: 19.
169+
b := a reshape: #( 2 3 4 ).
170+
self assert: b strides equals: #( 12 4 1 ).
171+
self assert: (b flattenedIndexOf: #( 2 2 3 )) equals: 19
172+
]
173+
174+
{ #category : #tests }
175+
PMNDArrayTest >> testView [
176+
177+
| t t1 |
178+
t := PMNDArray fromNestedArray:
179+
#( #( 10 11 12 ) #( 13 14 15 ) #( 16 17 18 ) #( #( 20 21 22 )
180+
#( 23 24 25 ) #( 26 27 28 ) )
181+
#( #( 30 31 32 ) #( 33 34 35 ) #( 36 37 38 ) ) ).
182+
t1 := t view.
183+
self assert: t asArray == t1 asArray equals: true.
184+
self assert: t shape equals: t1 shape.
185+
self assert: t shape == t1 shape equals: false.
186+
self assert: t strides equals: t1 strides.
187+
self assert: t strides == t1 strides equals: false.
188+
self assert: t first equals: t1 first.
189+
self assert: t first == t1 first equals: false.
190+
191+
]
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
Class {
2+
#name : #ShapeMismatch,
3+
#superclass : #Error,
4+
#category : #'Math-Matrix'
5+
}
6+
7+
{ #category : #accessing }
8+
ShapeMismatch >> messageText [
9+
"Overwritten to initialiaze the message text to a standard text if it has not yet been set"
10+
11+
^ messageText ifNil: [ messageText := self standardMessageText ]
12+
]
13+
14+
{ #category : #printing }
15+
ShapeMismatch >> standardMessageText [
16+
"Generate a standard textual description"
17+
18+
^ 'Tensor shapes do not match'
19+
]

0 commit comments

Comments
 (0)