Skip to content

Commit a45671c

Browse files
authored
featurization cleanup (#23)
* featurization cleanup 1
* tests are almost fixed
* cleanup
* private members in base
* private calls in tests
* commenting out invalid test
* black
* embed to __init__
* dead code cleanup
* black
1 parent fce4bad commit a45671c

File tree

7 files changed

+186
-197
lines changed

7 files changed

+186
-197
lines changed

src/learn_to_pick/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
VwPolicy,
1313
VwLogger,
1414
embed,
15-
stringify_embedding,
1615
)
1716
from learn_to_pick.pick_best import (
1817
PickBest,

src/learn_to_pick/base.py

Lines changed: 16 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -87,11 +87,11 @@ def EmbedAndKeep(anything: Any) -> Any:
8787
# helper functions
8888

8989

90-
def stringify_embedding(embedding: List) -> str:
90+
def _stringify_embedding(embedding: List) -> str:
9191
return " ".join([f"{i}:{e}" for i, e in enumerate(embedding)])
9292

9393

94-
def parse_lines(parser: "vw.TextFormatParser", input_str: str) -> List["vw.Example"]:
94+
def _parse_lines(parser: "vw.TextFormatParser", input_str: str) -> List["vw.Example"]:
9595
return [parser.parse_line(line) for line in input_str.split("\n")]
9696

9797

@@ -116,20 +116,6 @@ def get_based_on_and_to_select_from(inputs: Dict[str, Any]) -> Tuple[Dict, Dict]
116116
return based_on, to_select_from
117117

118118

119-
def prepare_inputs_for_autoembed(inputs: Dict[str, Any]) -> Dict[str, Any]:
120-
"""
121-
go over all the inputs and if something is either wrapped in _ToSelectFrom or _BasedOn, and if their inner values are not already _Embed,
122-
then wrap them in EmbedAndKeep while retaining their _ToSelectFrom or _BasedOn status
123-
"""
124-
125-
next_inputs = inputs.copy()
126-
for k, v in next_inputs.items():
127-
if isinstance(v, _ToSelectFrom) or isinstance(v, _BasedOn):
128-
if not isinstance(v.value, _Embed):
129-
next_inputs[k].value = EmbedAndKeep(v.value)
130-
return next_inputs
131-
132-
133119
# end helper functions
134120

135121

@@ -195,15 +181,15 @@ def predict(self, event: TEvent) -> Any:
195181

196182
text_parser = vw.TextFormatParser(self.workspace)
197183
return self.workspace.predict_one(
198-
parse_lines(text_parser, self.featurizer.format(event))
184+
_parse_lines(text_parser, self.featurizer.format(event))
199185
)
200186

201187
def learn(self, event: TEvent) -> None:
202188
import vowpal_wabbit_next as vw
203189

204190
vw_ex = self.featurizer.format(event)
205191
text_parser = vw.TextFormatParser(self.workspace)
206-
multi_ex = parse_lines(text_parser, vw_ex)
192+
multi_ex = _parse_lines(text_parser, vw_ex)
207193
self.workspace.learn_one(multi_ex)
208194

209195
def log(self, event: TEvent) -> None:
@@ -489,20 +475,13 @@ def run(self, *args, **kwargs) -> Dict[str, Any]:
489475
return {"picked": picked, "picked_metadata": event}
490476

491477

492-
def is_stringtype_instance(item: Any) -> bool:
493-
"""Helper function to check if an item is a string."""
494-
return isinstance(item, str) or (
495-
isinstance(item, _Embed) and isinstance(item.value, str)
496-
)
497-
498-
499-
def embed_string_type(
478+
def _embed_string_type(
500479
item: Union[str, _Embed], model: Any, namespace: Optional[str] = None
501480
) -> Dict[str, Union[str, List[str]]]:
502481
"""Helper function to embed a string or an _Embed object."""
503482
keep_str = ""
504483
if isinstance(item, _Embed):
505-
encoded = stringify_embedding(model.encode(item.value))
484+
encoded = _stringify_embedding(model.encode(item.value))
506485
if item.keep:
507486
keep_str = item.value.replace(" ", "_") + " "
508487
elif isinstance(item, str):
@@ -518,36 +497,36 @@ def embed_string_type(
518497
return {namespace: keep_str + encoded}
519498

520499

521-
def embed_dict_type(item: Dict, model: Any) -> Dict[str, Any]:
500+
def _embed_dict_type(item: Dict, model: Any) -> Dict[str, Any]:
522501
"""Helper function to embed a dictionary item."""
523502
inner_dict: Dict = {}
524503
for ns, embed_item in item.items():
525504
if isinstance(embed_item, list):
526505
inner_dict[ns] = []
527506
for embed_list_item in embed_item:
528-
embedded = embed_string_type(embed_list_item, model, ns)
507+
embedded = _embed_string_type(embed_list_item, model, ns)
529508
inner_dict[ns].append(embedded[ns])
530509
else:
531-
inner_dict.update(embed_string_type(embed_item, model, ns))
510+
inner_dict.update(_embed_string_type(embed_item, model, ns))
532511
return inner_dict
533512

534513

535-
def embed_list_type(
514+
def _embed_list_type(
536515
item: list, model: Any, namespace: Optional[str] = None
537516
) -> List[Dict[str, Union[str, List[str]]]]:
538517
ret_list: List = []
539518
for embed_item in item:
540519
if isinstance(embed_item, dict):
541-
ret_list.append(embed_dict_type(embed_item, model))
520+
ret_list.append(_embed_dict_type(embed_item, model))
542521
elif isinstance(embed_item, list):
543-
item_embedding = embed_list_type(embed_item, model, namespace)
522+
item_embedding = _embed_list_type(embed_item, model, namespace)
544523
# Get the first key from the first dictionary
545524
first_key = next(iter(item_embedding[0]))
546525
# Group the values under that key
547526
grouping = {first_key: [item[first_key] for item in item_embedding]}
548527
ret_list.append(grouping)
549528
else:
550-
ret_list.append(embed_string_type(embed_item, model, namespace))
529+
ret_list.append(_embed_string_type(embed_item, model, namespace))
551530
return ret_list
552531

553532

@@ -569,10 +548,10 @@ def embed(
569548
if (isinstance(to_embed, _Embed) and isinstance(to_embed.value, str)) or isinstance(
570549
to_embed, str
571550
):
572-
return [embed_string_type(to_embed, model, namespace)]
551+
return [_embed_string_type(to_embed, model, namespace)]
573552
elif isinstance(to_embed, dict):
574-
return [embed_dict_type(to_embed, model)]
553+
return [_embed_dict_type(to_embed, model)]
575554
elif isinstance(to_embed, list):
576-
return embed_list_type(to_embed, model, namespace)
555+
return _embed_list_type(to_embed, model, namespace)
577556
else:
578557
raise ValueError("Invalid input format for embedding")

src/learn_to_pick/pick_best.py

Lines changed: 55 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
from __future__ import annotations
22

33
import logging
4-
from typing import Any, Dict, List, Optional, Tuple, Type, Union
4+
from typing import Any, Dict, List, Optional, Tuple, Type, Union, Iterable
5+
from itertools import chain
56
import os
67

78
from learn_to_pick import base
@@ -42,6 +43,28 @@ def __init__(
4243
self.based_on = based_on
4344

4445

46+
class VwTxt:
    """Stateless helpers that render features as VW-style text fragments.

    All methods are static; the class only serves as a namespace.
    """

    @staticmethod
    def embedding(embedding: List[float]) -> str:
        """Render a dense vector as space-separated ``index:value`` pairs.

        An empty vector yields an empty string.
        """
        return " ".join(f"{i}:{e}" for i, e in enumerate(embedding))

    @staticmethod
    def features(features: Union[str, List[Any]]) -> str:
        """Render the features of one namespace as a single string.

        A list is joined with single spaces; any other value is returned
        unchanged. List elements are passed through ``str`` so that
        non-string features (e.g. numbers) no longer make ``join`` raise
        ``TypeError``.
        """
        if isinstance(features, list):
            return " ".join(str(f) for f in features)
        return features

    @staticmethod
    def _namespaces(ns: Iterable[Tuple[str, Union[str, List[str]]]]) -> str:
        """Render ``(name, features)`` pairs as ``|name features`` chunks."""
        return " ".join(f"|{k} {VwTxt.features(v)}" for k, v in ns)

    @staticmethod
    def ns(
        ns: Union[Iterable[Tuple[str, Any]], List[Dict[str, Any]], Dict[str, Any]]
    ) -> str:
        """Render namespaces given as a dict, a list, or an iterable of pairs.

        Args:
            ns: One of
                * a dict mapping namespace name to features,
                * a list whose items are such dicts and/or ``(name, features)``
                  pairs (dicts are flattened in order),
                * any other iterable of ``(name, features)`` pairs.

        Returns:
            The namespaces joined by single spaces, each formatted as
            ``|name features``; an empty input yields an empty string.
        """
        if isinstance(ns, dict):
            pairs: Iterable[Tuple[str, Any]] = ns.items()
        elif isinstance(ns, list):
            # A list may legitimately hold pairs rather than dicts (the
            # signature allows any iterable of pairs), so dispatch per item
            # instead of calling dict.items on everything.
            pairs = chain.from_iterable(
                item.items() if isinstance(item, dict) else (item,) for item in ns
            )
        else:
            pairs = ns
        return VwTxt._namespaces(pairs)
66+
67+
4568
class PickBestFeaturizer(base.Featurizer[PickBestEvent]):
4669
"""
4770
Text Featurizer class that embeds the `BasedOn` and `ToSelectFrom` inputs into a format that can be used by the learning policy
@@ -63,10 +86,6 @@ def __init__(
6386
self.model = model
6487
self.auto_embed = auto_embed
6588

66-
@staticmethod
67-
def _str(embedding: List[float]) -> str:
68-
return " ".join([f"{i}:{e}" for i, e in enumerate(embedding)])
69-
7089
def get_label(self, event: PickBestEvent) -> tuple:
7190
cost = None
7291
if event.selected:
@@ -148,70 +167,48 @@ def format_auto_embed_on(self, event: PickBestEvent) -> str:
148167
context_emb, action_embs = self.get_context_and_action_embeddings(event)
149168
indexed_dot_product = self.get_indexed_dot_product(context_emb, action_embs)
150169

151-
action_lines = []
170+
nactions = len(action_embs)
171+
172+
def _tolist(v):
173+
return v if isinstance(v, list) else [v]
174+
175+
labels = ["" for _ in range(nactions)]
176+
if cost is not None:
177+
labels[chosen_action] = f"{chosen_action}:{cost}:{prob} "
178+
179+
dotprods = [{} for _ in range(nactions)]
152180
for i, action in enumerate(action_embs):
153-
line_parts = []
154-
dot_prods = []
155-
if cost is not None and chosen_action == i:
156-
line_parts.append(f"{chosen_action}:{cost}:{prob}")
157-
for ns, action in action.items():
158-
line_parts.append(f"|{ns}")
159-
elements = action if isinstance(action, list) else [action]
160-
nsa = []
161-
for elem in elements:
162-
line_parts.append(f"{elem}")
163-
ns_a = f"{ns}={elem}"
164-
nsa.append(ns_a)
165-
for k, v in indexed_dot_product.items():
166-
dot_prods.append(v[ns_a])
167-
nsa_str = " ".join(nsa)
168-
line_parts.append(f"|# {nsa_str}")
169-
170-
line_parts.append(f"|dotprod {self._str(dot_prods)}")
171-
action_lines.append(" ".join(line_parts))
172-
173-
shared = []
181+
action["#"] = [f"{k}={v}" for k, _v in action.items() for v in _tolist(_v)]
182+
dotprods[i] = [
183+
v[f] for v in indexed_dot_product.values() for f in action["#"]
184+
]
185+
186+
actions_str = [
187+
f"{l}{VwTxt.ns(a)} |dotprod {VwTxt.embedding(dp)}"
188+
for l, a, dp in zip(labels, action_embs, dotprods)
189+
]
190+
174191
for item in context_emb:
175-
for ns, context in item.items():
176-
shared.append(f"|{ns}")
177-
elements = context if isinstance(context, list) else [context]
178-
nsc = []
179-
for elem in elements:
180-
shared.append(f"{elem}")
181-
nsc.append(f"{ns}={elem}")
182-
nsc_str = " ".join(nsc)
183-
shared.append(f"|@ {nsc_str}")
184-
185-
return "shared " + " ".join(shared) + "\n" + "\n".join(action_lines)
192+
item["@"] = [f"{k}={v}" for k, _v in item.items() for v in _tolist(_v)]
193+
shared_str = f"shared {VwTxt.ns(context_emb)}"
194+
195+
return "\n".join([shared_str] + actions_str)
186196

187197
def format_auto_embed_off(self, event: PickBestEvent) -> str:
188198
"""
189199
Converts the `BasedOn` and `ToSelectFrom` into a format that can be used by VW
190200
"""
191201
chosen_action, cost, prob = self.get_label(event)
192202
context_emb, action_embs = self.get_context_and_action_embeddings(event)
203+
nactions = len(action_embs)
193204

194-
example_string = ""
195-
example_string += "shared "
196-
for context_item in context_emb:
197-
for ns, based_on in context_item.items():
198-
e = " ".join(based_on) if isinstance(based_on, list) else based_on
199-
example_string += f"|{ns} {e} "
200-
example_string += "\n"
205+
context_str = f"shared {VwTxt.ns(context_emb)}"
201206

202-
for i, action in enumerate(action_embs):
203-
if cost is not None and chosen_action == i:
204-
example_string += f"{chosen_action}:{cost}:{prob} "
205-
for ns, action_embedding in action.items():
206-
e = (
207-
" ".join(action_embedding)
208-
if isinstance(action_embedding, list)
209-
else action_embedding
210-
)
211-
example_string += f"|{ns} {e} "
212-
example_string += "\n"
213-
# Strip the last newline
214-
return example_string[:-1]
207+
labels = ["" for _ in range(nactions)]
208+
if cost is not None:
209+
labels[chosen_action] = f"{chosen_action}:{cost}:{prob} "
210+
actions_str = [f"{l}{VwTxt.ns(a)}" for a, l in zip(action_embs, labels)]
211+
return "\n".join([context_str] + actions_str)
215212

216213
def format(self, event: PickBestEvent) -> str:
217214
if self.auto_embed:

0 commit comments

Comments
 (0)