Skip to content

Commit 52679de

Browse files
authored
Merge pull request #2 from Fraunhofer-AISEC/regularization_control
feat: add regularization type (L1, L2) and regularization strength control
2 parents 19db3a7 + 189ed9f commit 52679de

File tree

3 files changed

+58
-4
lines changed

3 files changed

+58
-4
lines changed

app/layout.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,38 @@
6464
multi=False
6565
),
6666
], className="columnContainer"),
67+
html.Div([
68+
html.H3("Regularization Type"),
69+
dcc.Dropdown(
70+
id="select_reg_type",
71+
options=[
72+
{"value": "none", "label": "None"},
73+
{"value": "l1", "label": "L1"},
74+
{"value": "l2", "label": "L2"},
75+
],
76+
value="none",
77+
multi=False
78+
),
79+
], className="columnContainer"),
80+
html.Div([
81+
html.H3("Regularization Strength"),
82+
dcc.Dropdown(
83+
id="select_reg_strength",
84+
options=[
85+
{"value": 0.001, "label": 0.001},
86+
{"value": 0.005, "label": 0.005},
87+
{"value": 0.01, "label": 0.01},
88+
{"value": 0.05, "label": 0.05},
89+
{"value": 0.1, "label": 0.1},
90+
{"value": 0.5, "label": 0.5},
91+
{"value": 1.0, "label": 1.0},
92+
{"value": 5.0, "label": 5.0},
93+
{"value": 10.0, "label": 10.0},
94+
],
95+
value=0.01,
96+
multi=False
97+
),
98+
], className="columnContainer"),
6799
], className="rowContainer longContainer",
68100
style={'justify-content': 'space-around',
69101
'margin': '0px'}

app/logic.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -299,13 +299,16 @@ def update_circuit_plot(num_qubits: int, num_layers: int, model_parameters: str)
299299
state=[
300300
State(component_id="select_lr", component_property="value"),
301301
State(component_id="select_batch_size", component_property="value"),
302+
State(component_id="select_reg_type", component_property="value"),
303+
State(component_id="select_reg_strength", component_property="value"),
302304
State(component_id="train_datastore", component_property="data"),
303305
State(component_id="model_parameters", component_property="data"),
304306
],
305307
)
306308
def single_epoch(num_clicks: int, num_intervals: int, reset_clicks: int,
307309
num_qubits: int, num_layers: int, selected_data_set: str,
308-
lr: float, batch_size: int, train_data, model_parameters):
310+
lr: float, batch_size: int, reg_type: str, reg_strength: float,
311+
train_data, model_parameters):
309312
"""
310313
Performs a single training epoch for the quantum model using the provided parameters.
311314
@@ -325,6 +328,10 @@ def single_epoch(num_clicks: int, num_intervals: int, reset_clicks: int,
325328
:type lr: float
326329
:param batch_size: Size of training batches
327330
:type batch_size: int
331+
:param reg_type: Type of regularization (none, l1, l2)
332+
:type reg_type: str
333+
:param reg_strength: Strength of regularization
334+
:type reg_strength: float
328335
:param train_data: Training data in JSON format
329336
:param model_parameters: Current model parameters
330337
:return: Updated model parameters and current epoch number
@@ -345,7 +352,7 @@ def single_epoch(num_clicks: int, num_intervals: int, reset_clicks: int,
345352
model_parameters = unserialize_model_dict(model_parameters)
346353
qcl.load_model(model_parameters)
347354

348-
qcl.train_single_epoch(df_train[["x", "y"]].values, df_train["label"].values, lr, batch_size)
355+
qcl.train_single_epoch(df_train[["x", "y"]].values, df_train["label"].values, lr, batch_size, reg_type, reg_strength)
349356
model_parameters = qcl.save_model()
350357

351358
return [json.dumps(model_parameters), model_parameters["config"]["epoch"]]

app/models/reuploading_classifier.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -262,7 +262,7 @@ def evaluate(self, X, y):
262262

263263
return results
264264

265-
def train_single_epoch(self, X, y, lr=0.1, batch_size=32):
265+
def train_single_epoch(self, X, y, lr=0.1, batch_size=32, reg_type="none", reg_strength=0.01):
266266
"""Train the quantum model for a single epoch.
267267
268268
This method performs one epoch of training using mini-batch gradient descent.
@@ -274,16 +274,31 @@ def train_single_epoch(self, X, y, lr=0.1, batch_size=32):
274274
y (numpy.ndarray): Training labels of shape (n_samples,)
275275
lr (float, optional): Learning rate for the optimizer. Defaults to 0.1.
276276
batch_size (int, optional): Size of mini-batches for training. Defaults to 32.
277+
reg_type (str, optional): Type of regularization ('none', 'l1', 'l2'). Defaults to "none".
278+
reg_strength (float, optional): Strength of regularization. Defaults to 0.01.
277279
278280
Returns:
279281
None
280282
"""
281-
opt = Adam(self.model.parameters(), lr=lr, betas=(0.9, 0.999))
283+
# Configure weight decay (L2 regularization) if selected
284+
weight_decay = 0.0
285+
if reg_type == "l2":
286+
weight_decay = reg_strength
287+
288+
opt = Adam(self.model.parameters(), lr=lr, betas=(0.9, 0.999), weight_decay=weight_decay)
282289

283290
for Xbatch, ybatch in batch_loader(X, y, batch_size=batch_size):
284291
opt.zero_grad()
285292
output_states = self.model(Xbatch)[-1]
286293
loss = self.loss(output_states, ybatch)
294+
295+
# Apply L1 regularization if selected
296+
if reg_type == "l1":
297+
l1_loss = 0
298+
for param in self.model.parameters():
299+
l1_loss += torch.sum(torch.abs(param))
300+
loss += reg_strength * l1_loss
301+
287302
loss.backward()
288303
opt.step()
289304

0 commit comments

Comments (0)