Skip to content

Commit 1b498f1

Browse files
authored
* fix_demo * fix_demo
1 parent 5142fc6 commit 1b498f1

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

demo/auto_compression/hyperparameter_tutorial.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,8 @@ Quantization:
2727
蒸馏参数主要设置蒸馏节点(`node`)和教师预测模型路径,如下所示:
2828
```yaml
2929
Distillation:
30-
# ahpha: 蒸馏loss所占权重;可输入多个数值,支持不同节点之间使用不同的ahpha值
31-
lambda: 1.0
30+
# alpha: 蒸馏loss所占权重;可输入多个数值,支持不同节点之间使用不同的ahpha值
31+
alpha: 1.0
3232
# loss: 蒸馏loss算法;可输入多个loss,支持不同节点之间使用不同的loss算法
3333
loss: l2
3434
# node: 蒸馏节点,即某层输出的变量名称,可以选择:

demo/auto_compression/image_classification/run.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,6 @@ def eval_function(exe, compiled_test_program, test_feed_names, test_fetch_list):
111111
train_config=train_config,
112112
train_dataloader=train_dataloader,
113113
eval_callback=eval_function,
114-
eval_dataloader=reader_wrapper(eval_reader(data_dir, args.batch_size)), args.input_name)
114+
eval_dataloader=reader_wrapper(eval_reader(data_dir, args.batch_size), args.input_name))
115115

116116
ac.compress()

0 commit comments

Comments
 (0)