Skip to content

Commit 994b52f

Browse files
committed
Add layers for save/load op
1 parent 92974d4 commit 994b52f

File tree

1 file changed

+69
-1
lines changed
  • python/paddle/fluid/layers

1 file changed

+69
-1
lines changed

python/paddle/fluid/layers/nn.py

Lines changed: 69 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3206,7 +3206,7 @@ def one_hot(input, depth):
32063206
operator.
32073207
32083208
Args:
3209-
input(Tensor/LodTensor): A Tensor/LodTensor of indices, last dimension must be 1.
3209+
input(variable): A Tensor/LodTensor of indices, last dimension must be 1.
32103210
depth(scalar): an interger defining the depth of the one hot dimension.
32113211
32123212
Returns:
@@ -3265,3 +3265,71 @@ def autoincreased_step_counter(counter_name=None, begin=1, step=1):
32653265
counter.stop_gradient = True
32663266

32673267
return counter
3268+
3269+
3270+
def save(x, file_path, overwrite=True):
3271+
"""
3272+
Saves a variable as a file.
3273+
3274+
Args:
3275+
x(variable): The Tensor/LoDTensor to be saved.
3276+
file_path(str): The file path where the variable will be saved.
3277+
overwrite(bool): Whether or not cover the given file when it has already
3278+
existed. If it's set 'False' and the file is existed, a runtime
3279+
error will be thrown.
3280+
"""
3281+
helper = LayerHelper("save", **locals())
3282+
helper.append_op(
3283+
type="save",
3284+
inputs={"input": x},
3285+
outputs={},
3286+
args={"file_path": file_path,
3287+
"overwrite": overwrite})
3288+
3289+
3290+
def save_combine(x, file_path, overwrite=True):
3291+
"""
3292+
Saves a variable as a file.
3293+
3294+
Args:
3295+
x(list): A list of Tensor/LoDTensor to be saved together in a single file.
3296+
file_path(str): The file path where variables will be saved.
3297+
overwrite(bool): Whether or not cover the given file when it has already
3298+
existed. If it's set 'False' and the file is existed, a runtime
3299+
error will be thrown.
3300+
"""
3301+
helper = LayerHelper("save_combine", **locals())
3302+
helper.append_op(
3303+
type="save_combine",
3304+
inputs={"input": x},
3305+
outputs={},
3306+
args={"file_path": file_path,
3307+
"overwrite": overwrite})
3308+
3309+
3310+
def load(out, file_path):
3311+
"""
3312+
Args:
3313+
out(variable): The variable to be read from the disk file.
3314+
file_path(str): The path of the disk file.
3315+
"""
3316+
helper = LayerHelper("load", **locals())
3317+
helper.append_op(
3318+
type="load",
3319+
inputs={},
3320+
output={"Out": out},
3321+
args={"file_path": file_path})
3322+
3323+
3324+
def load_combine(out, file_path):
3325+
"""
3326+
Args:
3327+
out(list): The list of variables to be read from the disk file.
3328+
file_path(str): The path of the disk file.
3329+
"""
3330+
helper = LayerHelper("load_combine", **locals())
3331+
helper.append_op(
3332+
type="load_combine",
3333+
inputs={},
3334+
output={"Out": out},
3335+
args={"file_path": file_path})

0 commit comments

Comments
 (0)