Commit 269008b

Add proper renderer support with loss masking and LoRA rank config
1 parent a37a7c4 commit 269008b

2 files changed: +56 -18 lines

data_loader.py

Lines changed: 55 additions & 18 deletions
@@ -8,8 +8,10 @@
 
 try:
     from tinker import types
+    from tinker_cookbook import renderers
 except ImportError:
     types = None
+    renderers = None
 
 
 class DataLoader:
@@ -94,18 +96,20 @@ def prepare_training_data(
         self,
         train_file: str,
         tokenizer: Any,
+        renderer_name: str = "llama3",
         deduplicate: bool = True,
     ) -> List[Any]:
         """
-        Load and convert training data into Tinker Datum objects.
+        Load and convert training data into Tinker Datum objects using proper renderers.
 
         Args:
             train_file: Path to training JSONL file.
             tokenizer: Tokenizer from Tinker training client.
+            renderer_name: Name of the renderer to use (e.g., "llama3", "qwen3").
             deduplicate: Whether to deduplicate examples.
 
         Returns:
-            List of tinker.types.Datum objects.
+            List of tinker.types.Datum objects with proper loss masking.
         """
         if types is None:
             raise ImportError("tinker package required for data preparation")
@@ -127,29 +131,62 @@ def prepare_training_data(
             print(f"Deduplicated to {len(unique_examples)} unique examples")
             valid_examples = unique_examples
 
+        renderer = None
+        if renderers is not None:
+            try:
+                renderer = renderers.get_renderer(renderer_name, tokenizer)
+                print(f"Using {renderer_name} renderer for proper loss masking")
+            except Exception as e:
+                print(f"Warning: Could not load renderer, falling back to simple tokenization: {e}")
+
         datums = []
         for ex in valid_examples:
            instruction = ex["instruction"]
            input_text = ex.get("input", "")
            output_text = ex["output"]
 
-            if input_text:
-                prompt = f"{instruction}\n\nInput: {input_text}\n\nResponse:"
+            if renderer is not None:
+                user_content = f"{instruction}\n\nInput: {input_text}" if input_text else instruction
+                messages = [
+                    {"role": "user", "content": user_content},
+                    {"role": "assistant", "content": output_text},
+                ]
+
+                try:
+                    tokens, weights = renderer.build_supervised_example(messages)
+                    token_list = tokens.to_ints() if hasattr(tokens, 'to_ints') else tokens
+
+                    if len(token_list) > self.max_seq_length:
+                        continue
+
+                    input_tokens = token_list[:-1]
+                    target_tokens = token_list[1:]
+                    weights = weights[1:]
+
+                    datum = types.Datum(
+                        model_input=types.ModelInput.from_ints(tokens=input_tokens),
+                        loss_fn_inputs={"weights": weights, "target_tokens": target_tokens},
+                    )
+                    datums.append(datum)
+                except Exception as e:
+                    print(f"Warning: Skipping example due to rendering error: {e}")
             else:
-                prompt = f"{instruction}\n\nResponse:"
-
-            full_text = f"{prompt} {output_text}"
-
-            tokens = tokenizer.encode(full_text)
-            if len(tokens) > self.max_seq_length:
-                print(f"Warning: Skipping example with {len(tokens)} tokens (max: {self.max_seq_length})")
-                continue
-
-            datum = types.Datum(
-                model_input=tokens,
-                loss_fn_inputs={"target": tokens},
-            )
-            datums.append(datum)
+                if input_text:
+                    prompt = f"{instruction}\n\nInput: {input_text}\n\nResponse:"
+                else:
+                    prompt = f"{instruction}\n\nResponse:"
+
+                full_text = f"{prompt} {output_text}"
+                tokens = tokenizer.encode(full_text)
+
+                if len(tokens) > self.max_seq_length:
+                    continue
+
+                datum = types.Datum(
+                    model_input=tokens,
+                    loss_fn_inputs={"target": tokens},
+                )
+                datums.append(datum)
 
         print(f"Prepared {len(datums)} training datums")
         return datums
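
For context on the new renderer path: it builds one token sequence for the user/assistant exchange, shifts it by one position for next-token prediction, and keeps per-token weights so loss is only applied where the weights are non-zero (presumably the assistant turn). A minimal, self-contained sketch of that shift-and-mask step, using illustrative names rather than the tinker_cookbook API:

# Illustrative sketch (not the tinker_cookbook API): assembling a loss-masked
# next-token-prediction example from prompt and response token lists.
from typing import List, Tuple

def build_masked_example(
    prompt_tokens: List[int], response_tokens: List[int]
) -> Tuple[List[int], List[int], List[float]]:
    tokens = prompt_tokens + response_tokens
    # Weight 0.0 on prompt tokens, 1.0 on response tokens: loss is only
    # computed where the model should learn to produce the assistant text.
    weights = [0.0] * len(prompt_tokens) + [1.0] * len(response_tokens)
    # Shift by one for next-token prediction: position i predicts token i+1.
    input_tokens = tokens[:-1]
    target_tokens = tokens[1:]
    target_weights = weights[1:]
    return input_tokens, target_tokens, target_weights

# Example: a 3-token prompt and a 2-token response.
inp, tgt, w = build_masked_example([101, 7592, 102], [2023, 103])
print(inp)  # [101, 7592, 102, 2023]
print(tgt)  # [7592, 102, 2023, 103]
print(w)    # [0.0, 0.0, 1.0, 1.0]

In the committed code, renderer.build_supervised_example produces the tokens and weights directly; the sketch only illustrates the shift-and-mask bookkeeping that follows it.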

trainer_with_eval.py

Lines changed: 1 addition & 0 deletions
@@ -217,6 +217,7 @@ async def async_main(config_path: str) -> None:
         train_file=config.train_file,
         tokenizer=tokenizer,
         max_seq_length=config.max_seq_length,
+        renderer_name=config.renderer_name,
         deduplicate=True,
     )
     if not datums:
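
Since the call site now reads config.renderer_name, the training config presumably gains a matching field. A hypothetical sketch of what that addition might look like; the actual config class in this repo may differ in name, fields, and defaults:

# Hypothetical config sketch implied by the call site above; only train_file,
# max_seq_length, and renderer_name appear in this commit, defaults are assumed.
from dataclasses import dataclass

@dataclass
class TrainConfig:
    train_file: str
    max_seq_length: int = 2048        # assumed default
    renderer_name: str = "llama3"     # mirrors prepare_training_data's default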
