Skip to content

Commit 66168f6

Browse files
authored
Support string input for ofa export (PaddlePaddle#964)
* support string input for ofa export
1 parent a8c16a7 commit 66168f6

File tree

2 files changed

+22
-0
lines changed

2 files changed

+22
-0
lines changed

paddleslim/nas/ofa/ofa.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,21 @@
7979
DistillConfig.__new__.__defaults__ = (None, ) * len(DistillConfig._fields)
8080

8181

82+
def to_tensor(string_values, name="text"):
83+
"""
84+
Create the tensor that the value holds the list of string.
85+
NOTICE: The value will be holded in the cpu place.
86+
87+
Parameters:
88+
string_values(list[string]): The value will be setted to the tensor.
89+
name(string): The name of the tensor.
90+
"""
91+
tensor = paddle.Tensor(core.VarDesc.VarType.STRING, [], name,
92+
core.VarDesc.VarType.STRINGS, False)
93+
tensor.value().set_string_list(string_values)
94+
return tensor
95+
96+
8297
class OFABase(Layer):
8398
def __init__(self, model):
8499
super(OFABase, self).__init__()
@@ -531,6 +546,8 @@ def build_input(input_size, dtypes):
531546
dtype = dtypes[0]
532547
else:
533548
dtype = dtypes
549+
if dtype == core.VarDesc.VarType.STRINGS:
550+
return to_tensor([""])
534551
return paddle.cast(paddle.rand(list(input_size)), dtype)
535552
if isinstance(input_size, dict):
536553
inputs = {}

paddleslim/quant/quanter.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,7 @@ def quant_post_static(
307307
quantize_model_path,
308308
batch_generator=None,
309309
sample_generator=None,
310+
data_loader=None,
310311
model_filename=None,
311312
params_filename=None,
312313
save_model_filename='__model__',
@@ -346,6 +347,9 @@ def quant_post_static(
346347
can be set. Beisdes, batch_generator supports lod tensor.
347348
sample_generator(Python Generator): The sample generator provides
348349
calibrate data for DataLoader, and it only returns a sample every time.
350+
data_loader(Python Generator, Paddle.io.DataLoader, optional): The
351+
Generator or Dataloader provides calibrate data, and it could
352+
return a batch every time.
349353
model_filename(str, optional): The name of model file. If parameters
350354
are saved in separate files, set it as 'None'. Default: 'None'.
351355
params_filename(str, optional): The name of params file.
@@ -399,6 +403,7 @@ def quant_post_static(
399403
executor=executor,
400404
sample_generator=sample_generator,
401405
batch_generator=batch_generator,
406+
data_loader=data_loader,
402407
model_dir=model_dir,
403408
model_filename=model_filename,
404409
params_filename=params_filename,

0 commit comments

Comments
 (0)