Skip to content

Commit 7db0dfe

Browse files
committed
fix dataloader
1 parent 864d531 commit 7db0dfe

File tree

2 files changed

+12
-10
lines changed

2 files changed

+12
-10
lines changed

core/trainers/framework/runner.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -174,18 +174,20 @@ def _executor_dataloader_train(self, model_dict, context):
174174
fetch_list=metrics_varnames,
175175
return_numpy=False)
176176

177+
metrics = [batch_id]
178+
metrics_rets = [
179+
as_numpy(metrics_tensor)
180+
for metrics_tensor in metrics_tensors
181+
]
182+
metrics.extend(metrics_rets)
183+
177184
if batch_id % fetch_period == 0 and batch_id != 0:
178-
metrics = [batch_id]
179185
end_time = time.time()
180186
seconds = end_time - begin_time
181-
metrics.extend([seconds])
187+
metrics_logging = metrics[:]
188+
metrics_logging = metrics.insert(1, seconds)
182189
begin_time = end_time
183190

184-
metrics_rets = [
185-
as_numpy(metrics_tensor)
186-
for metrics_tensor in metrics_tensors
187-
]
188-
metrics.extend(metrics_rets)
189191
logging.info(metrics_format.format(*metrics))
190192
batch_id += 1
191193
except fluid.core.EOFException:

models/multitask/mmoe/config.yaml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,12 @@ workspace: "models/multitask/mmoe"
1717
dataset:
1818
- name: dataset_train
1919
batch_size: 5
20-
type: QueueDataset
20+
type: DataLoader # or QueueDataset
2121
data_path: "{workspace}/data/train"
2222
data_converter: "{workspace}/census_reader.py"
2323
- name: dataset_infer
2424
batch_size: 5
25-
type: QueueDataset
25+
type: DataLoader # or QueueDataset
2626
data_path: "{workspace}/data/train"
2727
data_converter: "{workspace}/census_reader.py"
2828

@@ -48,7 +48,7 @@ runner:
4848
save_inference_interval: 4
4949
save_checkpoint_path: "increment"
5050
save_inference_path: "inference"
51-
print_interval: 10
51+
print_interval: 1
5252
- name: infer_runner
5353
class: infer
5454
init_model_path: "increment/1"

0 commit comments

Comments
 (0)