Skip to content

Commit be6e5e2

Browse files
committed
further templatize span comparison example
1 parent 1b9b64d commit be6e5e2

File tree

2 files changed

+90
-58
lines changed

2 files changed

+90
-58
lines changed

nlp_span_comparison/README.md

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,22 @@
55
This notebook can be used as a template for comparing NLP models that predict
66
spans. Given two models and a sequence of text examples from which to extract
77
spans, the notebook presents the model predictions on each example and
8-
lets you indicate which model yielded the better prediction.
8+
lets you indicate which model yielded the better prediction. Your preferences
9+
are saved (and loaded) from storage, letting you use this as a real tool.
10+
11+
To use this notebook for your own data, just replace the implementations
12+
of the following three functions:
13+
14+
* `load_examples`: Load your own examples (strings) from a file or database.
15+
* `model_a_predictor`: Predict a span for a given example using model A.
16+
* `model_b_predictor`: Predict a span for a given example using model B.
17+
18+
The notebook keeps track of your preferences in a JSON file. To track
19+
preferences in a different way, such as in a database, replace the implementations
20+
of the following two functions:
21+
22+
* `load_choices`
23+
* `save_choices`
924

1025
## Running this notebook
1126

nlp_span_comparison/nlp_span_comparison.py

Lines changed: 74 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
import marimo
99

10-
__generated_with = "0.10.12"
10+
__generated_with = "0.10.13"
1111
app = marimo.App()
1212

1313

@@ -17,6 +17,45 @@ def _(mo):
1717
return
1818

1919

20+
@app.cell
21+
def _(textwrap, urllib):
22+
# Modify this function to load your own examples
23+
def load_examples():
24+
hamlet_url = "https://gist.githubusercontent.com/provpup/2fc41686eab7400b796b/raw/b575bd01a58494dfddc1d6429ef0167e709abf9b/hamlet.txt"
25+
26+
with urllib.request.urlopen(hamlet_url) as f:
27+
HAMLET = f.read().decode("utf-8")
28+
29+
return [
30+
textwrap.dedent(block).strip()[:1000]
31+
for block in HAMLET.split("\n\n")
32+
if block
33+
]
34+
return (load_examples,)
35+
36+
37+
@app.cell
38+
def _(random):
39+
# Replace with your predictor for model A
40+
def model_a_predictor(text: str) -> tuple[int, int]:
41+
random.seed(len(text))
42+
start = random.randint(0, len(text) - 2)
43+
end = random.randint(start + 1, len(text) - 1)
44+
return start, end
45+
return (model_a_predictor,)
46+
47+
48+
@app.cell
49+
def _(random):
50+
# Replace with your predictor for model B
51+
def model_b_predictor(text: str) -> tuple[int, int]:
52+
random.seed(len(text) / 2)
53+
start = random.randint(0, len(text) - 2)
54+
end = random.randint(start + 1, len(text) - 1)
55+
return start, end
56+
return (model_b_predictor,)
57+
58+
2059
@app.cell(hide_code=True)
2160
def _(mo):
2261
mo.md(
@@ -28,6 +67,12 @@ def _(mo):
2867
return
2968

3069

70+
@app.cell
71+
def _(load_examples):
72+
EXAMPLES = load_examples()
73+
return (EXAMPLES,)
74+
75+
3176
@app.cell
3277
def _(NUMBER_OF_EXAMPLES, mo):
3378
index = mo.ui.number(
@@ -73,28 +118,32 @@ def _(index):
73118

74119
@app.cell
75120
def _(CHOICES_PATH, get_choices, index, mo, write_choices):
76-
preference = get_choices()[index.value]["model"]
77-
mo.stop(preference is None, mo.md("**Choose the better model**.").center())
78-
write_choices(get_choices(), CHOICES_PATH)
79-
mo.md(f"You prefer **model {preference}**.").center()
80-
return (preference,)
121+
def _():
122+
preference = get_choices()[index.value]["model"]
123+
mo.stop(preference is None, mo.md("**Choose the better model**.").center())
124+
125+
write_choices(get_choices(), CHOICES_PATH)
126+
return mo.md(f"You prefer **model {preference}**.").center()
127+
128+
_()
129+
return
81130

82131

83132
@app.cell
84133
def _(annotate, mo):
85134
mo.hstack(
86135
[
87-
mo.md(annotate("Model A", [0, len("Model A")], "yellow")),
88-
mo.md(annotate("Model B", [0, len("Model B")], "lightblue")),
136+
annotate("Model A", [0, len("Model A")], "yellow"),
137+
annotate("Model B", [0, len("Model B")], "lightblue"),
89138
],
90139
justify="space-around",
91140
)
92141
return
93142

94143

95144
@app.cell
96-
def _(CHOICES_PATH, PARAGRAPHS, load_choices, mo):
97-
get_choices, set_choices = mo.state(load_choices(CHOICES_PATH, len(PARAGRAPHS)))
145+
def _(CHOICES_PATH, EXAMPLES, load_choices, mo):
146+
get_choices, set_choices = mo.state(load_choices(CHOICES_PATH, len(EXAMPLES)))
98147
return get_choices, set_choices
99148

100149

@@ -122,20 +171,24 @@ def _(index, mo, set_choices):
122171

123172

124173
@app.cell
125-
def _(PARAGRAPHS, SPANS, annotate, index, mo):
126-
model_A_prediction = mo.md(
127-
annotate(PARAGRAPHS[index.value], SPANS[index.value][0], color="yellow")
174+
def _(EXAMPLES, annotate, index, model_a_predictor, model_b_predictor):
175+
_example = EXAMPLES[index.value]
176+
177+
model_A_prediction = annotate(
178+
_example, model_a_predictor(_example), color="yellow"
128179
)
129180

130-
model_B_prediction = mo.md(
131-
annotate(PARAGRAPHS[index.value], SPANS[index.value][1], color="lightblue")
181+
model_B_prediction = annotate(
182+
_example, model_b_predictor(_example), color="lightblue"
132183
)
133184
return model_A_prediction, model_B_prediction
134185

135186

136187
@app.cell
137188
def _(mo, model_A_prediction, model_B_prediction):
138-
mo.hstack([model_A_prediction, model_B_prediction], gap=2, justify="space-around")
189+
mo.hstack(
190+
[model_A_prediction, model_B_prediction], gap=2, justify="space-around"
191+
)
139192
return
140193

141194

@@ -165,72 +218,37 @@ def load_choices(path, number_of_examples):
165218
assert len(choices) == number_of_examples
166219
return choices
167220

221+
168222
def write_choices(choices, path):
169223
# Trunacate notes
170224
with open(path, "w") as f:
171225
f.write(json.dumps(choices))
172-
173226
return load_choices, write_choices
174227

175228

176229
@app.cell
177-
def _(PARAGRAPHS, random):
178-
random.seed(0)
179-
180-
def predict_spans(text):
181-
first = [random.randint(0, len(text) - 2)]
182-
first.append(random.randint(first[0] + 1, len(text) - 1))
183-
second = [random.randint(0, len(text) - 2)]
184-
second.append(random.randint(second[0] + 1, len(text) - 1))
185-
186-
return first, second
187-
188-
SPANS = [predict_spans(p) for p in PARAGRAPHS]
189-
return SPANS, predict_spans
190-
191-
192-
@app.cell
193-
def _(HAMLET, textwrap):
194-
PARAGRAPHS = [
195-
textwrap.dedent(block).strip()[:1000] for block in HAMLET.split("\n\n") if block
196-
]
197-
return (PARAGRAPHS,)
198-
199-
200-
@app.cell
201-
def _():
230+
def _(mo):
202231
def annotate(text, span, color):
203232
mark_start = f"<mark style='background-color:{color}'>"
204-
return (
233+
return mo.md(
205234
text[: span[0]]
206235
+ mark_start
207236
+ text[span[0] : span[1]]
208237
+ "</mark>"
209238
+ text[span[1] :]
210239
)
211-
212240
return (annotate,)
213241

214242

215243
@app.cell
216-
def _(PARAGRAPHS):
217-
NUMBER_OF_EXAMPLES = len(PARAGRAPHS)
244+
def _(EXAMPLES):
245+
NUMBER_OF_EXAMPLES = len(EXAMPLES)
218246
return (NUMBER_OF_EXAMPLES,)
219247

220248

221-
@app.cell
222-
def _(urllib):
223-
_hamlet_url = "https://gist.githubusercontent.com/provpup/2fc41686eab7400b796b/raw/b575bd01a58494dfddc1d6429ef0167e709abf9b/hamlet.txt"
224-
225-
with urllib.request.urlopen(_hamlet_url) as f:
226-
HAMLET = f.read().decode("utf-8")
227-
return HAMLET, f
228-
229-
230249
@app.cell
231250
def _():
232251
import marimo as mo
233-
234252
return (mo,)
235253

236254

@@ -241,7 +259,6 @@ def _():
241259
import random
242260
import textwrap
243261
import urllib
244-
245262
return json, os, random, textwrap, urllib
246263

247264

0 commit comments

Comments
 (0)