Skip to content

Commit ac8041c

Browse files
committed
Fixed logProbability and wrote a test for it
1 parent b52684f commit ac8041c

File tree

2 files changed

+28
-3
lines changed

2 files changed

+28
-3
lines changed

src/NgramModel-Tests/NgramModelTest.class.st

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,25 @@ NgramModelTest >> testTrainedModelCounts [
151151
self assert: (model countOfNgram: ngram5) equals: 1.
152152
]
153153

154+
{ #category : #tests }
155+
NgramModelTest >> testTrainedModelLogProbabilityOfText [
156+
| model text |
157+
model := NgramModel order: 2.
158+
159+
text := 'lorem ipsum ipsum ipsum dolor'.
160+
"<s> lorem 1
161+
lorem ipsum 1
162+
ipsum ipsum 2/3
163+
ipsum ipsum 2/3
164+
ipsum dolor 1/3
165+
dolor <s> 1"
166+
167+
model trainOnSentence: text.
168+
self
169+
assert: (model logProbabilityOfText: text)
170+
closeTo: (1 log + 1 log + (2/3) log + (2/3) log + (1/3) log + 1 log) asFloat.
171+
]
172+
154173
{ #category : #tests }
155174
NgramModelTest >> testTrainedModelProbabilitiesOfNgrams [
156175
| model text ngram1 ngram2 ngram3 ngram4 ngram5 |
@@ -204,11 +223,17 @@ NgramModelTest >> testTrainedModelSelfProbabilityOfText [
204223
model := NgramModel order: 2.
205224

206225
text := 'lorem ipsum ipsum ipsum dolor'.
207-
226+
"<s> lorem 1
227+
lorem ipsum 1
228+
ipsum ipsum 2/3
229+
ipsum ipsum 2/3
230+
ipsum dolor 1/3
231+
dolor <s> 1"
232+
208233
model trainOnSentence: text.
209234
self
210235
assert: (model probabilityOfText: text)
211-
closeTo: 2/3 * 2/3 * 1/3 asFloat.
236+
closeTo: (1 * 1 * 2/3 * 2/3 * 1/3 * 1) asFloat.
212237
]
213238

214239
{ #category : #tests }

src/NgramModel/NgramModel.class.st

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ NgramModel >> logProbabilityOfText: aString [
8484
| ngrams |
8585
ngrams := aString ngramsWithDefaultPadding: self order.
8686
^ (ngrams collect: [ :ngram | self probabilityOfNgram: ngram ])
87-
inject: 1 into: [ :prod :each | prod + each log ].
87+
inject: 0 into: [ :sum :each | sum + each log ].
8888
]
8989

9090
{ #category : #generation }

0 commit comments

Comments
 (0)