Skip to content

Commit 81b78e6

Browse files
authored
replace static.sparsity with incubate.asp (#4186)
* replace static.sparsity with incubate.asp * code style format
1 parent c984ee7 commit 81b78e6

File tree

1 file changed

+25
-22
lines changed

1 file changed

+25
-22
lines changed

model_zoo/bert/static/run_glue_with_sparaity.py

Lines changed: 25 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -13,28 +13,29 @@
1313
# limitations under the License.
1414

1515
import argparse
16-
import logging
1716
import os
1817
import random
1918
import time
2019
from functools import partial
2120

2221
import numpy as np
2322
import paddle
23+
from paddle.incubate import asp
2424
from paddle.io import DataLoader
25-
from paddlenlp.datasets import load_dataset
26-
2725
from paddle.metric import Accuracy
28-
from paddlenlp.data import Stack, Tuple, Pad
29-
from paddlenlp.data.sampler import SamplerHelper
30-
from paddlenlp.transformers import BertForSequenceClassification, BertTokenizer
31-
from paddlenlp.transformers import ErnieForSequenceClassification, ErnieTokenizer
32-
from paddlenlp.transformers import LinearDecayWithWarmup
26+
27+
from paddlenlp.data import Pad, Stack, Tuple
28+
from paddlenlp.datasets import load_dataset
3329
from paddlenlp.metrics import Mcc, PearsonAndSpearman
30+
from paddlenlp.transformers import (
31+
BertForSequenceClassification,
32+
BertTokenizer,
33+
ErnieForSequenceClassification,
34+
ErnieTokenizer,
35+
LinearDecayWithWarmup,
36+
)
3437
from paddlenlp.utils.log import logger
3538

36-
from paddle.static import sparsity
37-
3839
METRIC_CLASSES = {
3940
"cola": Mcc,
4041
"sst-2": Accuracy,
@@ -253,16 +254,18 @@ def do_train(args):
253254

254255
train_ds = train_ds.map(trans_func, lazy=True)
255256

256-
batchify_fn = lambda samples, fn=Tuple(
257-
Pad(axis=0, pad_val=tokenizer.pad_token_id), # input
258-
Pad(axis=0, pad_val=tokenizer.pad_token_type_id), # token_type
259-
Stack(dtype="int64" if train_ds.label_list else "float32"), # label
260-
): fn(samples)
257+
def batchify_fn(
258+
samples,
259+
fn=Tuple(
260+
Pad(axis=0, pad_val=tokenizer.pad_token_id), # input
261+
Pad(axis=0, pad_val=tokenizer.pad_token_type_id), # token_type
262+
Stack(dtype="int64" if train_ds.label_list else "float32"), # label
263+
),
264+
):
265+
return fn(samples)
261266

262267
train_batch_sampler = paddle.io.BatchSampler(train_ds, batch_size=args.batch_size, shuffle=True)
263268

264-
feed_list_name = []
265-
266269
# Define the input data and create the train/dev data_loader
267270
with paddle.static.program_guard(main_program, startup_program):
268271
[input_ids, token_type_ids, labels] = create_data_holder(args.task_name)
@@ -343,10 +346,10 @@ def do_train(args):
343346

344347
# Keep Pooler and task-specific layer dense.
345348
# Please note, excluded_layers must be set before calling `optimizer.minimize()`.
346-
sparsity.set_excluded_layers(main_program, [model.bert.pooler.dense.full_name(), model.classifier.full_name()])
347-
# Calling sparsity.decorate() to wrap minimize() in optimizer, which
349+
asp.set_excluded_layers(main_program, [model.bert.pooler.dense.full_name(), model.classifier.full_name()])
350+
# Calling asp.decorate() to wrap minimize() in optimizer, which
348351
# will insert necessary masking operations for ASP workflow.
349-
optimizer = sparsity.decorate(optimizer)
352+
optimizer = asp.decorate(optimizer)
350353
optimizer.minimize(loss)
351354

352355
# Create the metric pass for the validation
@@ -364,8 +367,8 @@ def do_train(args):
364367
paddle.static.set_program_state(main_program, reset_state_dict)
365368

366369
# Pruning model to be 2:4 sparse pattern
367-
# Must call `exe.run(startup_program)` first before calling `sparsity.prune_model`
368-
sparsity.prune_model(place, main_program)
370+
# Must call `exe.run(startup_program)` first before calling `asp.prune_model`
371+
asp.prune_model(place, main_program)
369372

370373
global_step = 0
371374
tic_train = time.time()

0 commit comments

Comments (0)