# limitations under the License.
import argparse
- import logging
import os
import random
import time
from functools import partial

import numpy as np
import paddle
+ from paddle.incubate import asp
from paddle.io import DataLoader
- from paddlenlp.datasets import load_dataset
-
from paddle.metric import Accuracy
- from paddlenlp.data import Stack, Tuple, Pad
- from paddlenlp.data.sampler import SamplerHelper
- from paddlenlp.transformers import BertForSequenceClassification, BertTokenizer
- from paddlenlp.transformers import ErnieForSequenceClassification, ErnieTokenizer
- from paddlenlp.transformers import LinearDecayWithWarmup
+
+ from paddlenlp.data import Pad, Stack, Tuple
+ from paddlenlp.datasets import load_dataset
from paddlenlp.metrics import Mcc, PearsonAndSpearman
+ from paddlenlp.transformers import (
+     BertForSequenceClassification,
+     BertTokenizer,
+     ErnieForSequenceClassification,
+     ErnieTokenizer,
+     LinearDecayWithWarmup,
+ )
from paddlenlp.utils.log import logger

- from paddle.static import sparsity
-
METRIC_CLASSES = {
    "cola": Mcc,
    "sst-2": Accuracy,
@@ -253,16 +254,18 @@ def do_train(args):
    train_ds = train_ds.map(trans_func, lazy=True)

-     batchify_fn = lambda samples, fn=Tuple(
-         Pad(axis=0, pad_val=tokenizer.pad_token_id),  # input
-         Pad(axis=0, pad_val=tokenizer.pad_token_type_id),  # token_type
-         Stack(dtype="int64" if train_ds.label_list else "float32"),  # label
-     ): fn(samples)
+     def batchify_fn(
+         samples,
+         fn=Tuple(
+             Pad(axis=0, pad_val=tokenizer.pad_token_id),  # input
+             Pad(axis=0, pad_val=tokenizer.pad_token_type_id),  # token_type
+             Stack(dtype="int64" if train_ds.label_list else "float32"),  # label
+         ),
+     ):
+         return fn(samples)

    train_batch_sampler = paddle.io.BatchSampler(train_ds, batch_size=args.batch_size, shuffle=True)

-     feed_list_name = []
-
    # Define the input data and create the train/dev data_loader
    with paddle.static.program_guard(main_program, startup_program):
        [input_ids, token_type_ids, labels] = create_data_holder(args.task_name)
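Note on the `batchify_fn` rewrite above: `Tuple(...)` splits each `(input_ids, token_type_ids, label)` sample by position and applies one collator per field. A minimal sketch with toy inputs (the `pad_val` of 0 is illustrative, not the tokenizer's actual pad id):

```python
from paddlenlp.data import Pad, Stack, Tuple

samples = [
    ([1, 2, 3], [0, 0, 0], 1),  # (input_ids, token_type_ids, label)
    ([4, 5], [0, 0], 0),
]
collate = Tuple(
    Pad(axis=0, pad_val=0),  # right-pad input_ids to the batch max length
    Pad(axis=0, pad_val=0),  # right-pad token_type_ids likewise
    Stack(dtype="int64"),  # stack scalar labels into one array
)
input_ids, token_type_ids, labels = collate(samples)
# input_ids.shape == (2, 3); the second row gets a trailing pad value.
```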
@@ -343,10 +346,10 @@ def do_train(args):
        # Keep Pooler and task-specific layer dense.
        # Please note, excluded_layers must be set before calling `optimizer.minimize()`.
-         sparsity.set_excluded_layers(main_program, [model.bert.pooler.dense.full_name(), model.classifier.full_name()])
-         # Calling sparsity.decorate() to wrap minimize() in optimizer, which
+         asp.set_excluded_layers(main_program, [model.bert.pooler.dense.full_name(), model.classifier.full_name()])
+         # Calling asp.decorate() to wrap minimize() in optimizer, which
        # will insert necessary masking operations for ASP workflow.
-         optimizer = sparsity.decorate(optimizer)
+         optimizer = asp.decorate(optimizer)
        optimizer.minimize(loss)

    # Create the metric pass for the validation
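For context on the `paddle.static.sparsity` to `paddle.incubate.asp` migration in this hunk, here is a minimal static-graph sketch of the same wiring, assuming the call signatures used in this diff; the toy network and the `"dense_head"` layer name are hypothetical.

```python
import paddle
from paddle.incubate import asp

paddle.enable_static()
main_program = paddle.static.Program()
startup_program = paddle.static.Program()

with paddle.static.program_guard(main_program, startup_program):
    x = paddle.static.data(name="x", shape=[None, 128], dtype="float32")
    y = paddle.static.data(name="y", shape=[None, 1], dtype="float32")
    hidden = paddle.static.nn.fc(x, size=64)  # prunable layer
    out = paddle.static.nn.fc(hidden, size=1, name="dense_head")  # task head
    loss = paddle.mean(paddle.nn.functional.square_error_cost(out, y))

    optimizer = paddle.optimizer.SGD(learning_rate=0.1)
    # Keep the task head dense; must be set before optimizer.minimize().
    asp.set_excluded_layers(main_program, ["dense_head"])
    optimizer = asp.decorate(optimizer)  # inserts the ASP mask variables/ops
    optimizer.minimize(loss)
```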
@@ -364,8 +367,8 @@ def do_train(args):
    paddle.static.set_program_state(main_program, reset_state_dict)

    # Pruning model to be 2:4 sparse pattern
-     # Must call `exe.run(startup_program)` first before calling `sparsity.prune_model`
-     sparsity.prune_model(place, main_program)
+     # Must call `exe.run(startup_program)` first before calling `asp.prune_model`
+     asp.prune_model(place, main_program)

    global_step = 0
    tic_train = time.time()
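Continuing the sketch above: pruning runs only after `exe.run(startup_program)` has initialized the parameters, which is the ordering the comment in this hunk calls out.

```python
place = paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda() else paddle.CPUPlace()
exe = paddle.static.Executor(place)
exe.run(startup_program)  # initialize parameters first
# Rewrite the weights in place to the 2:4 sparse pattern and fill the masks.
asp.prune_model(place, main_program)
```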