11"""PINA Callbacks Implementations"""
22
3+ import importlib .metadata
34import torch
45from lightning .pytorch .callbacks import Callback
56from ..label_tensor import LabelTensor
67from ..utils import check_consistency
78
89
910class R3Refinement (Callback ):
11+ """
12+ PINA Implementation of an R3 Refinement Callback.
13+ """
1014
1115 def __init__ (self , sample_every ):
1216 """
13- PINA Implementation of an R3 Refinement Callback.
14-
1517 This callback implements the R3 (Retain-Resample-Release) routine for
1618 sampling new points based on adaptive search.
1719 The algorithm incrementally accumulates collocation points in regions
18- of high PDE residuals, and releases those
19- with low residuals. Points are sampled uniformly in all regions
20- where sampling is needed.
20+ of high PDE residuals, and releases those with low residuals.
21+ Points are sampled uniformly in all regions where sampling is needed.
2122
2223 .. seealso::
2324
@@ -33,142 +34,148 @@ def __init__(self, sample_every):
3334 Example:
3435 >>> r3_callback = R3Refinement(sample_every=5)
3536 """
36- super ().__init__ ()
37-
38- # sample every
39- check_consistency (sample_every , int )
40- self ._sample_every = sample_every
41- self ._const_pts = None
42-
43- def _compute_residual (self , trainer ):
44- """
45- Computes the residuals for a PINN object.
46-
47- :return: the total loss, and pointwise loss.
48- :rtype: tuple
49- """
50-
51- # extract the solver and device from trainer
52- solver = trainer .solver
53- device = trainer ._accelerator_connector ._accelerator_flag
54- precision = trainer .precision
55- if precision == "64-true" :
56- precision = torch .float64
57- elif precision == "32-true" :
58- precision = torch .float32
59- else :
60- raise RuntimeError (
61- "Currently R3Refinement is only implemented "
62- "for precision '32-true' and '64-true', set "
63- "Trainer precision to match one of the "
64- "available precisions."
65- )
66-
67- # compute residual
68- res_loss = {}
69- tot_loss = []
70- for location in self ._sampling_locations : # TODO fix for new collector
71- condition = solver .problem .conditions [location ]
72- pts = solver .problem .input_pts [location ]
73- # send points to correct device
74- pts = pts .to (device = device , dtype = precision )
75- pts = pts .requires_grad_ (True )
76- pts .retain_grad ()
77- # PINN loss: equation evaluated only for sampling locations
78- target = condition .equation .residual (pts , solver .forward (pts ))
79- res_loss [location ] = torch .abs (target ).as_subclass (torch .Tensor )
80- tot_loss .append (torch .abs (target ))
81-
82- print (tot_loss )
83-
84- return torch .vstack (tot_loss ), res_loss
85-
86- def _r3_routine (self , trainer ):
87- """
88- R3 refinement main routine.
89-
90- :param Trainer trainer: PINA Trainer.
91- """
92- # compute residual (all device possible)
93- tot_loss , res_loss = self ._compute_residual (trainer )
94- tot_loss = tot_loss .as_subclass (torch .Tensor )
95-
96- # !!!!!! From now everything is performed on CPU !!!!!!
97-
98- # average loss
99- avg = (tot_loss .mean ()).to ("cpu" )
100- old_pts = {} # points to be retained
101- for location in self ._sampling_locations :
102- pts = trainer ._model .problem .input_pts [location ]
103- labels = pts .labels
104- pts = pts .cpu ().detach ().as_subclass (torch .Tensor )
105- residuals = res_loss [location ].cpu ()
106- mask = (residuals > avg ).flatten ()
107- if any (mask ): # append residuals greater than average
108- pts = (pts [mask ]).as_subclass (LabelTensor )
109- pts .labels = labels
110- old_pts [location ] = pts
111- numb_pts = self ._const_pts [location ] - len (old_pts [location ])
112- # sample new points
113- trainer ._model .problem .discretise_domain (
114- numb_pts , "random" , locations = [location ]
115- )
116-
117- else : # if no res greater than average, samples all uniformly
118- numb_pts = self ._const_pts [location ]
119- # sample new points
120- trainer ._model .problem .discretise_domain (
121- numb_pts , "random" , locations = [location ]
122- )
123- # adding previous population points
124- trainer ._model .problem .add_points (old_pts )
125-
126- # update dataloader
127- trainer ._create_or_update_loader ()
128-
129- def on_train_start (self , trainer , _ ):
130- """
131- Callback function called at the start of training.
132-
133- This method extracts the locations for sampling from the problem
134- conditions and calculates the total population.
135-
136- :param trainer: The trainer object managing the training process.
137- :type trainer: pytorch_lightning.Trainer
138- :param _: Placeholder argument (not used).
139-
140- :return: None
141- :rtype: None
142- """
143- # extract locations for sampling
144- problem = trainer .solver .problem
145- locations = []
146- for condition_name in problem .conditions :
147- condition = problem .conditions [condition_name ]
148- if hasattr (condition , "location" ):
149- locations .append (condition_name )
150- self ._sampling_locations = locations
151-
152- # extract total population
153- const_pts = {} # for each location, store the # of pts to keep constant
154- for location in self ._sampling_locations :
155- pts = trainer ._model .problem .input_pts [location ]
156- const_pts [location ] = len (pts )
157- self ._const_pts = const_pts
158-
159- def on_train_epoch_end (self , trainer , __ ):
160- """
161- Callback function called at the end of each training epoch.
162-
163- This method triggers the R3 routine for refinement if the current
164- epoch is a multiple of `_sample_every`.
165-
166- :param trainer: The trainer object managing the training process.
167- :type trainer: pytorch_lightning.Trainer
168- :param __: Placeholder argument (not used).
169-
170- :return: None
171- :rtype: None
172- """
173- if trainer .current_epoch % self ._sample_every == 0 :
174- self ._r3_routine (trainer )
37+ raise NotImplementedError (
38+ "R3Refinement callback is being refactored in the pina "
39+ f"{ importlib .metadata .metadata ('pina-mathlab' )['Version' ]} "
40+ "version. Please use version 0.1 if R3Refinement is required."
41+ )
42+
43+ # super().__init__()
44+
45+ # # sample every
46+ # check_consistency(sample_every, int)
47+ # self._sample_every = sample_every
48+ # self._const_pts = None
49+
50+ # def _compute_residual(self, trainer):
51+ # """
52+ # Computes the residuals for a PINN object.
53+
54+ # :return: the total loss, and pointwise loss.
55+ # :rtype: tuple
56+ # """
57+
58+ # # extract the solver and device from trainer
59+ # solver = trainer.solver
60+ # device = trainer._accelerator_connector._accelerator_flag
61+ # precision = trainer.precision
62+ # if precision == "64-true":
63+ # precision = torch.float64
64+ # elif precision == "32-true":
65+ # precision = torch.float32
66+ # else:
67+ # raise RuntimeError(
68+ # "Currently R3Refinement is only implemented "
69+ # "for precision '32-true' and '64-true', set "
70+ # "Trainer precision to match one of the "
71+ # "available precisions."
72+ # )
73+
74+ # # compute residual
75+ # res_loss = {}
76+ # tot_loss = []
77+ # for location in self._sampling_locations:
78+ # condition = solver.problem.conditions[location]
79+ # pts = solver.problem.input_pts[location]
80+ # # send points to correct device
81+ # pts = pts.to(device=device, dtype=precision)
82+ # pts = pts.requires_grad_(True)
83+ # pts.retain_grad()
84+ # # PINN loss: equation evaluated only for sampling locations
85+ # target = condition.equation.residual(pts, solver.forward(pts))
86+ # res_loss[location] = torch.abs(target).as_subclass(torch.Tensor)
87+ # tot_loss.append(torch.abs(target))
88+
89+ # print(tot_loss)
90+
91+ # return torch.vstack(tot_loss), res_loss
92+
93+ # def _r3_routine(self, trainer):
94+ # """
95+ # R3 refinement main routine.
96+
97+ # :param Trainer trainer: PINA Trainer.
98+ # """
99+ # # compute residual (all device possible)
100+ # tot_loss, res_loss = self._compute_residual(trainer)
101+ # tot_loss = tot_loss.as_subclass(torch.Tensor)
102+
103+ # # !!!!!! From now everything is performed on CPU !!!!!!
104+
105+ # # average loss
106+ # avg = (tot_loss.mean()).to("cpu")
107+ # old_pts = {} # points to be retained
108+ # for location in self._sampling_locations:
109+ # pts = trainer._model.problem.input_pts[location]
110+ # labels = pts.labels
111+ # pts = pts.cpu().detach().as_subclass(torch.Tensor)
112+ # residuals = res_loss[location].cpu()
113+ # mask = (residuals > avg).flatten()
114+ # if any(mask): # append residuals greater than average
115+ # pts = (pts[mask]).as_subclass(LabelTensor)
116+ # pts.labels = labels
117+ # old_pts[location] = pts
118+ # numb_pts = self._const_pts[location] - len(old_pts[location])
119+ # # sample new points
120+ # trainer._model.problem.discretise_domain(
121+ # numb_pts, "random", locations=[location]
122+ # )
123+
124+ # else: # if no res greater than average, samples all uniformly
125+ # numb_pts = self._const_pts[location]
126+ # # sample new points
127+ # trainer._model.problem.discretise_domain(
128+ # numb_pts, "random", locations=[location]
129+ # )
130+ # # adding previous population points
131+ # trainer._model.problem.add_points(old_pts)
132+
133+ # # update dataloader
134+ # trainer._create_or_update_loader()
135+
136+ # def on_train_start(self, trainer, _):
137+ # """
138+ # Callback function called at the start of training.
139+
140+ # This method extracts the locations for sampling from the problem
141+ # conditions and calculates the total population.
142+
143+ # :param trainer: The trainer object managing the training process.
144+ # :type trainer: pytorch_lightning.Trainer
145+ # :param _: Placeholder argument (not used).
146+
147+ # :return: None
148+ # :rtype: None
149+ # """
150+ # # extract locations for sampling
151+ # problem = trainer.solver.problem
152+ # locations = []
153+ # for condition_name in problem.conditions:
154+ # condition = problem.conditions[condition_name]
155+ # if hasattr(condition, "location"):
156+ # locations.append(condition_name)
157+ # self._sampling_locations = locations
158+
159+ # # extract total population
160+ # const_pts = {} # for each location, store the pts to keep constant
161+ # for location in self._sampling_locations:
162+ # pts = trainer._model.problem.input_pts[location]
163+ # const_pts[location] = len(pts)
164+ # self._const_pts = const_pts
165+
166+ # def on_train_epoch_end(self, trainer, __):
167+ # """
168+ # Callback function called at the end of each training epoch.
169+
170+ # This method triggers the R3 routine for refinement if the current
171+ # epoch is a multiple of `_sample_every`.
172+
173+ # :param trainer: The trainer object managing the training process.
174+ # :type trainer: pytorch_lightning.Trainer
175+ # :param __: Placeholder argument (not used).
176+
177+ # :return: None
178+ # :rtype: None
179+ # """
180+ # if trainer.current_epoch % self._sample_every == 0:
181+ # self._r3_routine(trainer)
0 commit comments