"""
Saving and loading a general checkpoint in PyTorch
==================================================
Saving and loading a general checkpoint model for inference or
resuming training can be helpful for picking up where you last left off.
When saving a general checkpoint, you must save more than just the
model's state_dict. It is important to also save the optimizer's
state_dict, as this contains buffers and parameters that are updated as
the model trains. Other items that you may want to save are the epoch
you left off on, the latest recorded training loss, external
``torch.nn.Embedding`` layers, and more, based on your own algorithm.

Introduction
------------
To save multiple checkpoints, you must organize them in a dictionary and
use ``torch.save()`` to serialize the dictionary. A common PyTorch
convention is to save these checkpoints using the ``.tar`` file
extension. To load the items, first initialize the model and optimizer,
then load the dictionary locally using ``torch.load()``. From here, you
can easily access the saved items by simply querying the dictionary as
you would expect.

In this recipe, we will explore how to save and load multiple
checkpoints.
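The round trip described above fits in a few lines. The following is a
minimal standalone sketch, not the recipe's own code: the one-layer model
and the ``checkpoint.tar`` filename are made up for illustration, while the
dictionary keys follow the convention used later in this recipe.

```python
import torch
import torch.nn as nn
import torch.optim as optim

# A throwaway model and optimizer, just to have state_dicts to save.
model = nn.Linear(3, 1)
optimizer = optim.SGD(model.parameters(), lr=0.01)

# Organize everything in one dictionary and serialize it; the ``.tar``
# extension is only a naming convention, ``torch.save()`` does the rest.
torch.save({'epoch': 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': 0.5}, 'checkpoint.tar')

# Load it back and query the dictionary as you would expect.
checkpoint = torch.load('checkpoint.tar')
print(checkpoint['epoch'], checkpoint['loss'])  # 1 0.5
```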

Setup
-----
Before we begin, we need to install ``torch`` if it isn't already
available.

::

   pip install torch

"""
######################################################################
# Steps
# -----
#
# 1. Import all necessary libraries for loading our data
# 2. Define and initialize the neural network
# 3. Initialize the optimizer
# 4. Save the general checkpoint
# 5. Load the general checkpoint
#
# 1. Import necessary libraries for loading our data
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# For this recipe, we will use ``torch`` and its subsidiaries ``torch.nn``
# and ``torch.optim``.
#

import torch
import torch.nn as nn
import torch.optim as optim


######################################################################
# 2. Define and initialize the neural network
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# For the sake of this example, we will create a neural network for
# training images. To learn more, see the Defining a Neural Network recipe.
#

class Net(nn.Module):
    def __init__(self):
        # ... (layer definitions elided in this excerpt; see the
        # Defining a Neural Network recipe for the full model) ...
        super(Net, self).__init__()

    def forward(self, x):
        # ... (forward pass elided in this excerpt) ...
        return x


net = Net()


######################################################################
# 3. Initialize the optimizer
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# We will use SGD with momentum.
#

optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
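This is why the optimizer's state_dict must be saved: with momentum, SGD
keeps a per-parameter momentum buffer that lives only in that state_dict.
A standalone sketch (the tiny linear model here is hypothetical, not the
recipe's ``Net``):

```python
import torch
import torch.nn as nn
import torch.optim as optim

m = nn.Linear(2, 1)
opt = optim.SGD(m.parameters(), lr=0.001, momentum=0.9)

# Before any step the optimizer holds no per-parameter state.
print(opt.state_dict()['state'])  # {}

# One training step creates the momentum buffers.
loss = m(torch.randn(4, 2)).sum()
loss.backward()
opt.step()

# Every parameter now has a 'momentum_buffer' entry that would be lost
# if only the model's state_dict were checkpointed.
state = opt.state_dict()['state']
print(all('momentum_buffer' in s for s in state.values()))  # True
```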


######################################################################
# 4. Save the general checkpoint
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# Collect all relevant information and build your dictionary.
#

# Additional information
EPOCH = 5
PATH = "model.pt"
LOSS = 0.4

torch.save({
            'epoch': EPOCH,
            'model_state_dict': net.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': LOSS,
            }, PATH)


######################################################################
# 5. Load the general checkpoint
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# Remember to first initialize the model and optimizer, then load the
# dictionary locally.
#

model = Net()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

model.eval()
# - or -
model.train()

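The difference between the two modes is easy to see with a dropout layer.
A quick standalone check (the toy tensor below is made up for
illustration, not part of the recipe):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
drop = nn.Dropout(p=0.5)
x = torch.ones(10)

drop.train()    # training mode: elements are randomly zeroed
print(drop(x))  # roughly half zeros, survivors scaled by 1/(1-p) = 2

drop.eval()     # evaluation mode: dropout is a no-op
print(drop(x))  # all ones, identical to the input
```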
######################################################################
# You must call ``model.eval()`` to set dropout and batch normalization
# layers to evaluation mode before running inference. Failing to do this
# will yield inconsistent inference results.
#
# If you wish to resume training, call ``model.train()`` to ensure these
# layers are in training mode.
#
# Congratulations! You have successfully saved and loaded a general
# checkpoint for inference and/or resuming training in PyTorch.
#
# Learn More
# ----------
#
# Take a look at these other recipes to continue your learning:
#
# - :doc:`/recipes/recipes/saving_and_loading_a_general_checkpoint`
# - :doc:`/recipes/recipes/saving_multiple_models_in_one_file`