Commit d759b21

Merge pull request #74 from 9bow/recipes_source/recipes/saving_and_loading_a_general_checkpoint
Translation of 'Saving and loading a general checkpoint in PyTorch' into Korean
2 parents f8a08d8 + 6d61014

File tree

1 file changed: 73 additions, 77 deletions
@@ -1,32 +1,29 @@
"""
Saving and loading a general checkpoint in PyTorch
==================================================

Saving and loading a general checkpoint model for inference or resuming
training can be helpful for picking up where you last left off. When
saving a general checkpoint, you must save more than just the model's
state_dict. It is important to also save the optimizer's state_dict, as
this contains buffers and parameters that are updated as the model
trains. Other items that you may want to save are the epoch you left off
on, the latest recorded training loss, external ``torch.nn.Embedding``
layers, and more, based on your own algorithm.

Introduction
------------
To save multiple checkpoints, you must organize them in a dictionary and
use ``torch.save()`` to serialize the dictionary. A common PyTorch
convention is to save these checkpoints using the ``.tar`` file
extension. To load the items, first initialize the model and optimizer,
then load the dictionary locally using ``torch.load()``. From here, you
can easily access the saved items by simply querying the dictionary as
you would expect.

In this recipe, we will explore how to save and load multiple
checkpoints.

Setup
-----
Before we begin, we need to install ``torch`` if it isn't already
available.


::
@@ -38,34 +35,34 @@


######################################################################
# Steps
# -----
#
# 1. Import all necessary libraries for loading our data
# 2. Define and initialize the neural network
# 3. Initialize the optimizer
# 4. Save the general checkpoint
# 5. Load the general checkpoint
#
# 1. Import necessary libraries for loading our data
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# For this recipe, we will use ``torch`` and its subsidiaries ``torch.nn``
# and ``torch.optim``.
#

import torch
import torch.nn as nn
import torch.optim as optim


######################################################################
# 2. Define and initialize the neural network
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# For the sake of this example, we will create a neural network for
# training images. To learn more, see the Defining a Neural Network
# recipe.
#

class Net(nn.Module):
    def __init__(self):
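The body of ``Net`` falls outside this diff hunk. As a rough, hypothetical stand-in (not necessarily the recipe's actual architecture), a small image-classification network of the kind described above might look like:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    # Hypothetical stand-in for the network definition elided from the diff.
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)       # 3-channel input, 5x5 kernels
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 10)         # 10 output classes

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))  # 32x32 -> 28x28 -> 14x14
        x = self.pool(F.relu(self.conv2(x)))  # 14x14 -> 10x10 -> 5x5
        x = x.view(-1, 16 * 5 * 5)            # flatten for the linear layers
        x = F.relu(self.fc1(x))
        return self.fc2(x)


net = Net()
out = net(torch.randn(1, 3, 32, 32))
print(out.shape)   # torch.Size([1, 10])
```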
@@ -91,23 +88,23 @@ def forward(self, x):


######################################################################
# 3. Initialize the optimizer
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# We will use SGD with momentum.
#

optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)


######################################################################
# 4. Save the general checkpoint
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# Collect all relevant information and build your dictionary.
#

# Additional information
EPOCH = 5
PATH = "model.pt"
LOSS = 0.4
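The ``torch.save()`` call that actually writes the checkpoint sits just below this hunk and is not shown in the diff. A self-contained sketch of that step, using a placeholder ``nn.Linear`` in place of the recipe's ``Net`` and a temporary path in place of ``"model.pt"``:

```python
import os
import tempfile

import torch
import torch.nn as nn
import torch.optim as optim

net = nn.Linear(8, 2)   # placeholder for the recipe's Net()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

EPOCH = 5
PATH = os.path.join(tempfile.mkdtemp(), "model.pt")
LOSS = 0.4

# Bundle everything needed to resume later into one dictionary and save it.
torch.save({
    'epoch': EPOCH,
    'model_state_dict': net.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': LOSS,
}, PATH)

print(os.path.exists(PATH))   # True
```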
@@ -121,12 +118,11 @@ def forward(self, x):


######################################################################
# 5. Load the general checkpoint
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# Remember to first initialize the model and optimizer, then load the
# dictionary locally.
#

model = Net()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
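The lines that read the checkpoint back (``torch.load()`` and the two ``load_state_dict()`` calls) fall between this hunk and the next. A self-contained sketch of that load step, with placeholder objects standing in for the recipe's ``Net`` and ``PATH``:

```python
import os
import tempfile

import torch
import torch.nn as nn
import torch.optim as optim

PATH = os.path.join(tempfile.mkdtemp(), "model.pt")

# Write a checkpoint first so the load step below has something to read.
saved_model = nn.Linear(8, 2)   # placeholder for the recipe's Net()
saved_opt = optim.SGD(saved_model.parameters(), lr=0.001, momentum=0.9)
torch.save({'epoch': 5,
            'model_state_dict': saved_model.state_dict(),
            'optimizer_state_dict': saved_opt.state_dict(),
            'loss': 0.4}, PATH)

# The load step: create fresh objects, then restore their state.
model = nn.Linear(8, 2)
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

print(epoch)   # 5
```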
@@ -138,25 +134,25 @@ def forward(self, x):
loss = checkpoint['loss']

model.eval()
# - or -
model.train()


######################################################################
# You must call ``model.eval()`` to set dropout and batch normalization
# layers to evaluation mode before running inference. Failing to do this
# will yield inconsistent inference results.
#
# If you wish to resume training, call ``model.train()`` to ensure these
# layers are in training mode.
#
# Congratulations! You have successfully saved and loaded a general
# checkpoint for inference and/or resuming training in PyTorch.
#
# Learn More
# ----------
#
# Take a look at these other recipes to continue your learning:
#
# - :doc:`/recipes/recipes/saving_and_loading_a_general_checkpoint`
# - :doc:`/recipes/recipes/saving_multiple_models_in_one_file`
