|
1 | 1 | """PINA Callbacks Implementations""" |
2 | 2 |
|
3 | 3 | import torch |
| 4 | +import importlib.metadata |
4 | 5 | from lightning.pytorch.callbacks import Callback |
5 | 6 | from ..label_tensor import LabelTensor |
6 | 7 | from ..utils import check_consistency |
@@ -33,142 +34,148 @@ def __init__(self, sample_every): |
33 | 34 | Example: |
34 | 35 | >>> r3_callback = R3Refinement(sample_every=5) |
35 | 36 | """ |
36 | | - super().__init__() |
37 | | - |
38 | | - # sample every |
39 | | - check_consistency(sample_every, int) |
40 | | - self._sample_every = sample_every |
41 | | - self._const_pts = None |
42 | | - |
43 | | - def _compute_residual(self, trainer): |
44 | | - """ |
45 | | - Computes the residuals for a PINN object. |
46 | | -
|
47 | | - :return: the total loss, and pointwise loss. |
48 | | - :rtype: tuple |
49 | | - """ |
50 | | - |
51 | | - # extract the solver and device from trainer |
52 | | - solver = trainer.solver |
53 | | - device = trainer._accelerator_connector._accelerator_flag |
54 | | - precision = trainer.precision |
55 | | - if precision == "64-true": |
56 | | - precision = torch.float64 |
57 | | - elif precision == "32-true": |
58 | | - precision = torch.float32 |
59 | | - else: |
60 | | - raise RuntimeError( |
61 | | - "Currently R3Refinement is only implemented " |
62 | | - "for precision '32-true' and '64-true', set " |
63 | | - "Trainer precision to match one of the " |
64 | | - "available precisions." |
65 | | - ) |
66 | | - |
67 | | - # compute residual |
68 | | - res_loss = {} |
69 | | - tot_loss = [] |
70 | | - for location in self._sampling_locations: # TODO fix for new collector |
71 | | - condition = solver.problem.conditions[location] |
72 | | - pts = solver.problem.input_pts[location] |
73 | | - # send points to correct device |
74 | | - pts = pts.to(device=device, dtype=precision) |
75 | | - pts = pts.requires_grad_(True) |
76 | | - pts.retain_grad() |
77 | | - # PINN loss: equation evaluated only for sampling locations |
78 | | - target = condition.equation.residual(pts, solver.forward(pts)) |
79 | | - res_loss[location] = torch.abs(target).as_subclass(torch.Tensor) |
80 | | - tot_loss.append(torch.abs(target)) |
81 | | - |
82 | | - print(tot_loss) |
83 | | - |
84 | | - return torch.vstack(tot_loss), res_loss |
85 | | - |
86 | | - def _r3_routine(self, trainer): |
87 | | - """ |
88 | | - R3 refinement main routine. |
89 | | -
|
90 | | - :param Trainer trainer: PINA Trainer. |
91 | | - """ |
92 | | - # compute residual (all device possible) |
93 | | - tot_loss, res_loss = self._compute_residual(trainer) |
94 | | - tot_loss = tot_loss.as_subclass(torch.Tensor) |
95 | | - |
96 | | - # !!!!!! From now everything is performed on CPU !!!!!! |
97 | | - |
98 | | - # average loss |
99 | | - avg = (tot_loss.mean()).to("cpu") |
100 | | - old_pts = {} # points to be retained |
101 | | - for location in self._sampling_locations: |
102 | | - pts = trainer._model.problem.input_pts[location] |
103 | | - labels = pts.labels |
104 | | - pts = pts.cpu().detach().as_subclass(torch.Tensor) |
105 | | - residuals = res_loss[location].cpu() |
106 | | - mask = (residuals > avg).flatten() |
107 | | - if any(mask): # append residuals greater than average |
108 | | - pts = (pts[mask]).as_subclass(LabelTensor) |
109 | | - pts.labels = labels |
110 | | - old_pts[location] = pts |
111 | | - numb_pts = self._const_pts[location] - len(old_pts[location]) |
112 | | - # sample new points |
113 | | - trainer._model.problem.discretise_domain( |
114 | | - numb_pts, "random", locations=[location] |
115 | | - ) |
116 | | - |
117 | | - else: # if no res greater than average, samples all uniformly |
118 | | - numb_pts = self._const_pts[location] |
119 | | - # sample new points |
120 | | - trainer._model.problem.discretise_domain( |
121 | | - numb_pts, "random", locations=[location] |
122 | | - ) |
123 | | - # adding previous population points |
124 | | - trainer._model.problem.add_points(old_pts) |
125 | | - |
126 | | - # update dataloader |
127 | | - trainer._create_or_update_loader() |
128 | | - |
129 | | - def on_train_start(self, trainer, _): |
130 | | - """ |
131 | | - Callback function called at the start of training. |
132 | | -
|
133 | | - This method extracts the locations for sampling from the problem |
134 | | - conditions and calculates the total population. |
135 | | -
|
136 | | - :param trainer: The trainer object managing the training process. |
137 | | - :type trainer: pytorch_lightning.Trainer |
138 | | - :param _: Placeholder argument (not used). |
139 | | -
|
140 | | - :return: None |
141 | | - :rtype: None |
142 | | - """ |
143 | | - # extract locations for sampling |
144 | | - problem = trainer.solver.problem |
145 | | - locations = [] |
146 | | - for condition_name in problem.conditions: |
147 | | - condition = problem.conditions[condition_name] |
148 | | - if hasattr(condition, "location"): |
149 | | - locations.append(condition_name) |
150 | | - self._sampling_locations = locations |
151 | | - |
152 | | - # extract total population |
153 | | - const_pts = {} # for each location, store the # of pts to keep constant |
154 | | - for location in self._sampling_locations: |
155 | | - pts = trainer._model.problem.input_pts[location] |
156 | | - const_pts[location] = len(pts) |
157 | | - self._const_pts = const_pts |
158 | | - |
159 | | - def on_train_epoch_end(self, trainer, __): |
160 | | - """ |
161 | | - Callback function called at the end of each training epoch. |
162 | | -
|
163 | | - This method triggers the R3 routine for refinement if the current |
164 | | - epoch is a multiple of `_sample_every`. |
165 | | -
|
166 | | - :param trainer: The trainer object managing the training process. |
167 | | - :type trainer: pytorch_lightning.Trainer |
168 | | - :param __: Placeholder argument (not used). |
169 | | -
|
170 | | - :return: None |
171 | | - :rtype: None |
172 | | - """ |
173 | | - if trainer.current_epoch % self._sample_every == 0: |
174 | | - self._r3_routine(trainer) |
| 37 | + raise NotImplementedError( |
| 38 | + "R3Refinement callback is being refactored in the " |
| 39 | + f"pina {importlib.metadata.metadata('pina-mathlab')['Version']} " |
| 40 | + "version. Please use version 0.1 if R3Refinement is needed." |
| 41 | + ) |
| 42 | + |
| 43 | + # super().__init__() |
| 44 | + |
| 45 | + # # sample every |
| 46 | + # check_consistency(sample_every, int) |
| 47 | + # self._sample_every = sample_every |
| 48 | + # self._const_pts = None |
| 49 | + |
| 50 | + # def _compute_residual(self, trainer): |
| 51 | + # """ |
| 52 | + # Computes the residuals for a PINN object. |
| 53 | + |
| 54 | + # :return: the total loss, and pointwise loss. |
| 55 | + # :rtype: tuple |
| 56 | + # """ |
| 57 | + |
| 58 | + # # extract the solver and device from trainer |
| 59 | + # solver = trainer.solver |
| 60 | + # device = trainer._accelerator_connector._accelerator_flag |
| 61 | + # precision = trainer.precision |
| 62 | + # if precision == "64-true": |
| 63 | + # precision = torch.float64 |
| 64 | + # elif precision == "32-true": |
| 65 | + # precision = torch.float32 |
| 66 | + # else: |
| 67 | + # raise RuntimeError( |
| 68 | + # "Currently R3Refinement is only implemented " |
| 69 | + # "for precision '32-true' and '64-true', set " |
| 70 | + # "Trainer precision to match one of the " |
| 71 | + # "available precisions." |
| 72 | + # ) |
| 73 | + |
| 74 | + # # compute residual |
| 75 | + # res_loss = {} |
| 76 | + # tot_loss = [] |
| 77 | + # for location in self._sampling_locations: # TODO fix for new collector |
| 78 | + # condition = solver.problem.conditions[location] |
| 79 | + # pts = solver.problem.input_pts[location] |
| 80 | + # # send points to correct device |
| 81 | + # pts = pts.to(device=device, dtype=precision) |
| 82 | + # pts = pts.requires_grad_(True) |
| 83 | + # pts.retain_grad() |
| 84 | + # # PINN loss: equation evaluated only for sampling locations |
| 85 | + # target = condition.equation.residual(pts, solver.forward(pts)) |
| 86 | + # res_loss[location] = torch.abs(target).as_subclass(torch.Tensor) |
| 87 | + # tot_loss.append(torch.abs(target)) |
| 88 | + |
| 89 | + # print(tot_loss) |
| 90 | + |
| 91 | + # return torch.vstack(tot_loss), res_loss |
| 92 | + |
| 93 | + # def _r3_routine(self, trainer): |
| 94 | + # """ |
| 95 | + # R3 refinement main routine. |
| 96 | + |
| 97 | + # :param Trainer trainer: PINA Trainer. |
| 98 | + # """ |
| 99 | + # # compute residual (all device possible) |
| 100 | + # tot_loss, res_loss = self._compute_residual(trainer) |
| 101 | + # tot_loss = tot_loss.as_subclass(torch.Tensor) |
| 102 | + |
| 103 | + # # !!!!!! From now everything is performed on CPU !!!!!! |
| 104 | + |
| 105 | + # # average loss |
| 106 | + # avg = (tot_loss.mean()).to("cpu") |
| 107 | + # old_pts = {} # points to be retained |
| 108 | + # for location in self._sampling_locations: |
| 109 | + # pts = trainer._model.problem.input_pts[location] |
| 110 | + # labels = pts.labels |
| 111 | + # pts = pts.cpu().detach().as_subclass(torch.Tensor) |
| 112 | + # residuals = res_loss[location].cpu() |
| 113 | + # mask = (residuals > avg).flatten() |
| 114 | + # if any(mask): # append residuals greater than average |
| 115 | + # pts = (pts[mask]).as_subclass(LabelTensor) |
| 116 | + # pts.labels = labels |
| 117 | + # old_pts[location] = pts |
| 118 | + # numb_pts = self._const_pts[location] - len(old_pts[location]) |
| 119 | + # # sample new points |
| 120 | + # trainer._model.problem.discretise_domain( |
| 121 | + # numb_pts, "random", locations=[location] |
| 122 | + # ) |
| 123 | + |
| 124 | + # else: # if no res greater than average, samples all uniformly |
| 125 | + # numb_pts = self._const_pts[location] |
| 126 | + # # sample new points |
| 127 | + # trainer._model.problem.discretise_domain( |
| 128 | + # numb_pts, "random", locations=[location] |
| 129 | + # ) |
| 130 | + # # adding previous population points |
| 131 | + # trainer._model.problem.add_points(old_pts) |
| 132 | + |
| 133 | + # # update dataloader |
| 134 | + # trainer._create_or_update_loader() |
| 135 | + |
| 136 | + # def on_train_start(self, trainer, _): |
| 137 | + # """ |
| 138 | + # Callback function called at the start of training. |
| 139 | + |
| 140 | + # This method extracts the locations for sampling from the problem |
| 141 | + # conditions and calculates the total population. |
| 142 | + |
| 143 | + # :param trainer: The trainer object managing the training process. |
| 144 | + # :type trainer: pytorch_lightning.Trainer |
| 145 | + # :param _: Placeholder argument (not used). |
| 146 | + |
| 147 | + # :return: None |
| 148 | + # :rtype: None |
| 149 | + # """ |
| 150 | + # # extract locations for sampling |
| 151 | + # problem = trainer.solver.problem |
| 152 | + # locations = [] |
| 153 | + # for condition_name in problem.conditions: |
| 154 | + # condition = problem.conditions[condition_name] |
| 155 | + # if hasattr(condition, "location"): |
| 156 | + # locations.append(condition_name) |
| 157 | + # self._sampling_locations = locations |
| 158 | + |
| 159 | + # # extract total population |
| 160 | + # const_pts = {} # for each location, store the # of pts to keep constant |
| 161 | + # for location in self._sampling_locations: |
| 162 | + # pts = trainer._model.problem.input_pts[location] |
| 163 | + # const_pts[location] = len(pts) |
| 164 | + # self._const_pts = const_pts |
| 165 | + |
| 166 | + # def on_train_epoch_end(self, trainer, __): |
| 167 | + # """ |
| 168 | + # Callback function called at the end of each training epoch. |
| 169 | + |
| 170 | + # This method triggers the R3 routine for refinement if the current |
| 171 | + # epoch is a multiple of `_sample_every`. |
| 172 | + |
| 173 | + # :param trainer: The trainer object managing the training process. |
| 174 | + # :type trainer: pytorch_lightning.Trainer |
| 175 | + # :param __: Placeholder argument (not used). |
| 176 | + |
| 177 | + # :return: None |
| 178 | + # :rtype: None |
| 179 | + # """ |
| 180 | + # if trainer.current_epoch % self._sample_every == 0: |
| 181 | + # self._r3_routine(trainer) |
0 commit comments