Skip to content

Commit b3283d1

Browse files
Updated QAT Walkthrough ntoebook
Signed-off-by: Farshad Ghodsian <[email protected]>
1 parent 776302e commit b3283d1

File tree

1 file changed

+145
-20
lines changed

1 file changed

+145
-20
lines changed

examples/llm_qat/notebooks/QAT_Walkthrough.ipynb

Lines changed: 145 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@
9191
},
9292
{
9393
"cell_type": "code",
94-
"execution_count": 6,
94+
"execution_count": 1,
9595
"id": "6d25c2b1-a68b-4748-ac29-e8a893ce1762",
9696
"metadata": {},
9797
"outputs": [],
@@ -109,7 +109,7 @@
109109
},
110110
{
111111
"cell_type": "code",
112-
"execution_count": 14,
112+
"execution_count": 2,
113113
"id": "0ec71181-770a-4ee6-8760-c62cfab8340f",
114114
"metadata": {},
115115
"outputs": [],
@@ -129,14 +129,14 @@
129129
},
130130
{
131131
"cell_type": "code",
132-
"execution_count": 18,
132+
"execution_count": 3,
133133
"id": "5f946576-83ac-45b5-a290-9a2167193e3d",
134134
"metadata": {},
135135
"outputs": [
136136
{
137137
"data": {
138138
"application/vnd.jupyter.widget-view+json": {
139-
"model_id": "37c5f366ef204794bad4711ae6056d6c",
139+
"model_id": "9d21656629a64d6187b68dc703cb57c7",
140140
"version_major": 2,
141141
"version_minor": 0
142142
},
@@ -151,8 +151,7 @@
151151
"source": [
152152
"model = AutoModelForCausalLM.from_pretrained(model_name).cuda()\n",
153153
"tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
154-
"tokenizer.pad_token = tokenizer.eos_token\n",
155-
"tokenizer.padding_side = \"left\" # Setting this as tokenizer with the right padding_side may impact calibration accuracy."
154+
"tokenizer.pad_token = tokenizer.eos_token"
156155
]
157156
},
158157
{
@@ -165,10 +164,19 @@
165164
},
166165
{
167166
"cell_type": "code",
168-
"execution_count": 19,
167+
"execution_count": 4,
169168
"id": "f3b618e9-fdee-46b2-8d7e-f11f1f7ada8d",
170169
"metadata": {},
171-
"outputs": [],
170+
"outputs": [
171+
{
172+
"name": "stderr",
173+
"output_type": "stream",
174+
"text": [
175+
"/home/fghodsian/.venv/jupyter/lib/python3.12/site-packages/modelopt/torch/utils/dataset_utils.py:157: UserWarning: Tokenizer with the right padding_side may impact calibration accuracy. Recommend set to left\n",
176+
" warn(\n"
177+
]
178+
}
179+
],
172180
"source": [
173181
"\n",
174182
"# Calibration dataloader\n",
@@ -201,7 +209,7 @@
201209
},
202210
{
203211
"cell_type": "code",
204-
"execution_count": 17,
212+
"execution_count": 5,
205213
"id": "51c0c1bb-2804-45ae-873f-e33388458e04",
206214
"metadata": {},
207215
"outputs": [
@@ -217,7 +225,7 @@
217225
"name": "stderr",
218226
"output_type": "stream",
219227
"text": [
220-
"100%|█████████████████████████████████████████████████████████████████████████████| 64/64 [01:14<00:00, 1.16s/it]\n"
228+
"100%|█████████████████████████████████████████████████████████████████████████████| 64/64 [01:13<00:00, 1.15s/it]\n"
221229
]
222230
}
223231
],
@@ -237,7 +245,7 @@
237245
},
238246
{
239247
"cell_type": "code",
240-
"execution_count": 13,
248+
"execution_count": 6,
241249
"id": "c1a15f93-ee06-42a5-ab3b-ca3428a62fe7",
242250
"metadata": {},
243251
"outputs": [],
@@ -248,27 +256,144 @@
248256
},
249257
{
250258
"cell_type": "code",
251-
"execution_count": 14,
252-
"id": "95411f4c-b1d3-4e82-9afb-2608bd21a9a4",
259+
"execution_count": 7,
260+
"id": "e5ff221a-d807-450b-a099-6481cb3b00d0",
261+
"metadata": {},
262+
"outputs": [],
263+
"source": [
264+
"from datasets import load_dataset\n",
265+
"from transformers import DataCollatorForLanguageModeling\n",
266+
"\n",
267+
"# Load training dataset (for demonstration, use cnn_dailymail \"train\" split)\n",
268+
"train_dataset = load_dataset(\"cnn_dailymail\", '3.0.0', split=\"train[:1000]\") # Smaller subset for example\n",
269+
"\n",
270+
"def preprocess_function(examples):\n",
271+
" # Concatenate the article and highlights for training\n",
272+
" inputs = [a + \" \" + h for a, h in zip(examples[\"article\"], examples[\"highlights\"])]\n",
273+
" model_inputs = tokenizer(inputs, padding=\"max_length\", truncation=True, max_length=512)\n",
274+
" model_inputs[\"labels\"] = model_inputs[\"input_ids\"].copy() # Language modeling: teacher-forced\n",
275+
" return model_inputs\n",
276+
"\n",
277+
"tokenized_train = train_dataset.map(preprocess_function, batched=True, remove_columns=train_dataset.column_names)\n",
278+
"\n",
279+
"# Data collator (for causal language modeling)\n",
280+
"data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)\n"
281+
]
282+
},
283+
{
284+
"cell_type": "code",
285+
"execution_count": 8,
286+
"id": "0f78bdcf-e2fc-49bd-b5b7-79de7260068d",
253287
"metadata": {},
254288
"outputs": [
255289
{
256-
"ename": "SyntaxError",
257-
"evalue": "incomplete input (1262456024.py, line 1)",
258-
"output_type": "error",
259-
"traceback": [
260-
" \u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[14]\u001b[39m\u001b[32m, line 1\u001b[39m\n\u001b[31m \u001b[39m\u001b[31mtrainer = Trainer(model=model, processing_class=tokenizer, args=training_args, **data_module\u001b[39m\n ^\n\u001b[31mSyntaxError\u001b[39m\u001b[31m:\u001b[39m incomplete input\n"
290+
"name": "stderr",
291+
"output_type": "stream",
292+
"text": [
293+
"/tmp/ipykernel_2585829/1505564370.py:15: FutureWarning: `tokenizer` is deprecated and will be removed in version 5.0.0 for `Trainer.__init__`. Use `processing_class` instead.\n",
294+
" trainer = Trainer(\n"
261295
]
262296
}
263297
],
264298
"source": [
265-
"trainer = Trainer(model=model, processing_class=tokenizer, args=training_args, **data_module"
299+
"from transformers import TrainingArguments, Trainer\n",
300+
"\n",
301+
"training_args = TrainingArguments(\n",
302+
" output_dir=\"./qat_model_output\",\n",
303+
" per_device_train_batch_size=2,\n",
304+
" num_train_epochs=2,\n",
305+
" learning_rate=1e-5, # As recommended for QAT in README\n",
306+
" logging_steps=50,\n",
307+
" save_steps=200,\n",
308+
" save_total_limit=2,\n",
309+
" report_to=\"none\",\n",
310+
" fp16=False\n",
311+
")\n",
312+
"\n",
313+
"trainer = Trainer(\n",
314+
" model=model,\n",
315+
" args=training_args,\n",
316+
" train_dataset=tokenized_train,\n",
317+
" data_collator=data_collator,\n",
318+
" tokenizer=tokenizer,\n",
319+
")\n"
320+
]
321+
},
322+
{
323+
"cell_type": "code",
324+
"execution_count": 9,
325+
"id": "d6ad0ebd-3804-4264-95f6-b6522bbb5e90",
326+
"metadata": {},
327+
"outputs": [
328+
{
329+
"data": {
330+
"text/html": [
331+
"\n",
332+
" <div>\n",
333+
" \n",
334+
" <progress value='126' max='126' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
335+
" [126/126 04:30, Epoch 2/2]\n",
336+
" </div>\n",
337+
" <table border=\"1\" class=\"dataframe\">\n",
338+
" <thead>\n",
339+
" <tr style=\"text-align: left;\">\n",
340+
" <th>Step</th>\n",
341+
" <th>Training Loss</th>\n",
342+
" </tr>\n",
343+
" </thead>\n",
344+
" <tbody>\n",
345+
" <tr>\n",
346+
" <td>50</td>\n",
347+
" <td>0.268200</td>\n",
348+
" </tr>\n",
349+
" <tr>\n",
350+
" <td>100</td>\n",
351+
" <td>0.250100</td>\n",
352+
" </tr>\n",
353+
" </tbody>\n",
354+
"</table><p>"
355+
],
356+
"text/plain": [
357+
"<IPython.core.display.HTML object>"
358+
]
359+
},
360+
"metadata": {},
361+
"output_type": "display_data"
362+
},
363+
{
364+
"data": {
365+
"text/plain": [
366+
"TrainOutput(global_step=126, training_loss=0.25205686735728433, metrics={'train_runtime': 274.3474, 'train_samples_per_second': 7.29, 'train_steps_per_second': 0.459, 'total_flos': 4.6110257184768e+16, 'train_loss': 0.25205686735728433, 'epoch': 2.0})"
367+
]
368+
},
369+
"execution_count": 9,
370+
"metadata": {},
371+
"output_type": "execute_result"
372+
}
373+
],
374+
"source": [
375+
"trainer.train()"
376+
]
377+
},
378+
{
379+
"cell_type": "code",
380+
"execution_count": 11,
381+
"id": "b6ca7a04-163e-498f-9a83-f22718fa5141",
382+
"metadata": {},
383+
"outputs": [],
384+
"source": [
385+
"# Save quantizer state for later resume/deploy\n",
386+
"import modelopt.torch.opt as mto\n",
387+
"torch.save(mto.modelopt_state(model), \"modelopt_quantizer_states.pt\")\n",
388+
"\n",
389+
"# Save the final weights\n",
390+
"trainer.save_model(\"./qat_model_output\")"
266391
]
267392
},
268393
{
269394
"cell_type": "code",
270395
"execution_count": null,
271-
"id": "e5ff221a-d807-450b-a099-6481cb3b00d0",
396+
"id": "e69417a6-a4e4-4541-a0ee-8035cfe7df76",
272397
"metadata": {},
273398
"outputs": [],
274399
"source": []

0 commit comments

Comments
 (0)