Skip to content

Commit 13222c7

Browse files
authored
fix: update how ink makes prediction (#71)
1 parent 5cae28f commit 13222c7

File tree

5 files changed

+48
-29
lines changed

5 files changed

+48
-29
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ exclude = ["truelearn.tests*"] # exclude tests from build artifacts
8282
# e.g. pip install truelearn[dev] will install the tests dependencies
8383
[project.optional-dependencies]
8484
tests = ["pytest>=7.2.1", "pytest-cov>=4.0.0", "pytest-socket>=0.6.0"]
85-
linters = ["prospector[with_bandit,with_mypy]>=1.8.4"]
85+
linters = ["prospector[with_bandit,with_mypy]==1.8.4"]
8686
docs = ["sphinx>=5.3.0", "furo>=2023.03.27","sphinx_copybutton>=0.5.1", "sphinx-gallery>=0.12.2", "Pillow>=9.4.0"]
8787
dev = ["truelearn[tests, linters, docs]","black>=22.12.0"]
8888

truelearn/learning/_ink_classifier.py

Lines changed: 10 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -64,9 +64,9 @@ class INKClassifier(BaseClassifier):
6464
... ink_classifier.predict_proba(event)
6565
... )
6666
...
67-
True 0.64105...
68-
False 0.44438...
69-
True 0.64909...
67+
True 0.64839...
68+
False 0.43767...
69+
True 0.65660...
7070
>>> ink_classifier.get_params(deep=False) # doctest:+ELLIPSIS
7171
{...'learner_meta_weights': LearnerMetaWeights(novelty_weights=Weights(\
7272
mean=0.20461..., variance=0.45871...), interest_weights=Weights(\
@@ -174,9 +174,9 @@ def __eval_matching_quality(
174174
bias_weights.mean * pred_bias,
175175
]
176176
team_learner_variance = [
177-
novelty_weights.variance * pred_novelty,
178-
interest_weights.variance * pred_interest,
179-
bias_weights.variance * pred_bias,
177+
novelty_weights.variance * (pred_novelty**2),
178+
interest_weights.variance * (pred_interest**2),
179+
bias_weights.variance * (pred_bias**2),
180180
]
181181
team_content_mean = [self._threshold]
182182
team_content_variance = []
@@ -203,21 +203,13 @@ def __create_env(self):
203203
def __update_weights(
204204
self,
205205
x: EventModel,
206-
pred_novelty: float,
207-
pred_interest: float,
208206
pred_actual: float,
209207
) -> None:
210208
"""Update the weights of novelty, interest and bias.
211209
212210
Args:
213211
x:
214212
A representation of the learning event.
215-
pred_novelty:
216-
The predicted probability of the learner's engagement by using
217-
NoveltyClassifier.
218-
pred_interest:
219-
The predicted probability of the learner's engagement by using
220-
InterestClassifier.
221213
pred_actual:
222214
Whether the learner actually engages in the given event. This value is
223215
either 0 or 1.
@@ -228,6 +220,9 @@ def __update_weights(
228220
if self._greedy and cur_pred == pred_actual:
229221
return
230222

223+
pred_novelty = self._novelty_classifier.predict_proba(x)
224+
pred_interest = self._interest_classifier.predict_proba(x)
225+
231226
# train
232227
env = self.__create_env()
233228
team_experts = (
@@ -278,11 +273,7 @@ def __update_weights(
278273
def fit(self, x: EventModel, y: bool) -> Self:
279274
self._novelty_classifier.fit(x, y)
280275
self._interest_classifier.fit(x, y)
281-
282-
pred_novelty = self._novelty_classifier.predict_proba(x)
283-
pred_interest = self._interest_classifier.predict_proba(x)
284-
285-
self.__update_weights(x, pred_novelty, pred_interest, y)
276+
self.__update_weights(x, y)
286277
return self
287278

288279
def predict(self, x: EventModel) -> bool:

truelearn/tests/test_learning.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -799,7 +799,11 @@ def test_ink_classifier_customize(self, train_cases, test_events):
799799
for event, label in zip(train_events, train_labels):
800800
classifier.fit(event, label)
801801

802-
expected_results = [0.4155257653300731, 0.3792233211000749, 0.35213145076551466]
802+
expected_results = [
803+
0.40575267541878457,
804+
0.36519542301026875,
805+
0.33362493980730495,
806+
]
803807
actual_results = [classifier.predict_proba(event) for event in test_events]
804808

805809
check_farray_close(actual_results, expected_results)
@@ -849,7 +853,7 @@ def test_ink_classifier(self, train_cases, test_events):
849853
for event, label in zip(train_events, train_labels):
850854
classifier.fit(event, label)
851855

852-
expected_results = [0.3943943468622016, 0.3536982390875026, 0.33082714771211985]
856+
expected_results = [0.3844070661899784, 0.3398805698754434, 0.3133264788862059]
853857
actual_results = [classifier.predict_proba(event) for event in test_events]
854858

855859
check_farray_close(actual_results, expected_results)

truelearn/tests/test_utils_visualisations.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
1-
# pylint: disable=missing-function-docstring,missing-class-docstring
1+
# pylint: disable=missing-function-docstring,missing-class-docstring,line-too-long
22
import functools
33
import random
44
import pathlib
5-
import filecmp
65
import types
76
import os
87
import sys
@@ -113,28 +112,42 @@ def file_comparison(plotter_type: str, config: Optional[Dict[str, Dict]] = None)
113112
For matplotlib type, the method will test `.png`.
114113
config:
115114
A dictionary containing the configuration for each extension.
115+
E.g. config={".png": {...}, ".json": {...}, ...}
116116
"""
117117
config = config or {}
118118

119119
if plotter_type == "plotly":
120120
# only support html and json for plotly
121121
# because the backend engine that plotly uses
122-
# to generate imgaes is platform dependent
122+
# to generate images is platform dependent
123123
# Therefore, to be able to provide consistent
124124
# and replicable tests, we test against json and html.
125125
extensions = {
126-
".json": config.get(".json", {}),
126+
".json": {
127+
**config.get(".json", {}),
128+
# to generate files with cross-platform consistent encoding
129+
"encoding": "utf-8",
130+
},
127131
".html": {
128132
**config.get(".html", {}),
129133
# overwrite settings for div_id and include_plotlyjs
130134
# as they directly affect the generated output
131135
"div_id": UUID,
132136
"include_plotlyjs": "https://cdn.plot.ly/plotly-2.20.0.min.js",
137+
# to generate files with cross-platform consistent encoding
138+
"encoding": "utf-8",
133139
},
134140
}
135141

136142
def file_cmp_func(filename1, filename2):
137-
return filecmp.cmp(filename1, filename2)
143+
# since we use utf-8 to save all text files
144+
# we can safely open them with utf-8 here
145+
with open(filename1, "rt", encoding="utf-8") as f1, open(
146+
filename2, "rt", encoding="utf-8"
147+
) as f2:
148+
# line by line comparison, ignore the differences in newline characters
149+
# see https://docs.python.org/3/library/functions.html#open-newline-parameter # noqa
150+
return f1.readlines() == f2.readlines()
138151

139152
elif plotter_type == "matplotlib":
140153
extensions = {

truelearn/utils/visualisations/_base.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -197,10 +197,16 @@ def savefig(self, file: str, **kargs):
197197
The default width of the image in the HTML file.
198198
default_height:
199199
The default height of the image in the HTML file.
200+
encoding:
201+
The encoding of the saved HTML file. If unspecified,
202+
the encoding will be utf-8.
200203
201204
If you want to export a JSON file, you can optionally pass in
202205
pretty:
203206
Whether the saved JSON representation should be pretty-printed.
207+
encoding:
208+
The encoding of the saved JSON file. If unspecified,
209+
the encoding will be utf-8.
204210
205211
If you want to export an image file, you can optionally pass in
206212
width:
@@ -213,10 +219,15 @@ def savefig(self, file: str, **kargs):
213219
to find out more supported arguments.
214220
"""
215221
if file.endswith(".html"):
216-
self.figure.write_html(file=file, **kargs)
222+
encoding = kargs.pop("encoding", None) or "utf-8"
223+
with open(file, mode="wt", encoding=encoding) as f:
224+
self.figure.write_html(file=f, **kargs)
217225
return
226+
218227
if file.endswith(".json"):
219-
self.figure.write_json(file=file, **kargs)
228+
encoding = kargs.pop("encoding", None) or "utf-8"
229+
with open(file, mode="wt", encoding=encoding) as f:
230+
self.figure.write_json(file=f, **kargs)
220231
return
221232

222233
self.figure.write_image(file=file, **kargs)

0 commit comments

Comments
 (0)