|
26 | 26 | from datetime import datetime |
27 | 27 |
|
28 | 28 | import requests |
29 | | -from vertexai.preview.evaluation import MetricPromptTemplateExamples |
30 | 29 |
|
31 | 30 | try: |
32 | 31 | from airflow.sdk import task |
@@ -166,17 +165,30 @@ def _get_actual_models(key) -> dict[str, str]: |
166 | 165 | "Baking a decadent chocolate cake requires creaming butter and sugar, beating in eggs and alternating dry ingredients with buttermilk before baking until done.", |
167 | 166 | ], |
168 | 167 | } |
169 | | -METRICS = [ |
170 | | - MetricPromptTemplateExamples.Pointwise.SUMMARIZATION_QUALITY, |
171 | | - MetricPromptTemplateExamples.Pointwise.GROUNDEDNESS, |
172 | | - MetricPromptTemplateExamples.Pointwise.VERBOSITY, |
173 | | - MetricPromptTemplateExamples.Pointwise.INSTRUCTION_FOLLOWING, |
174 | | - "exact_match", |
175 | | - "bleu", |
176 | | - "rouge_1", |
177 | | - "rouge_2", |
178 | | - "rouge_l_sum", |
179 | | -] |
| 168 | + |
| 169 | + |
def _get_metrics():
    """
    Build and return the list of evaluation metrics.

    ``vertexai`` is imported inside this function instead of at module
    level, so the import cost is paid only when the metrics are requested
    rather than every time this module is imported.

    A new list is created on every call: four pointwise prompt-template
    metrics followed by five metrics referenced by name.
    """
    from vertexai.preview.evaluation import MetricPromptTemplateExamples

    pointwise = MetricPromptTemplateExamples.Pointwise

    # Prompt-template metrics exposed by the Vertex AI SDK examples.
    template_metrics = [
        pointwise.SUMMARIZATION_QUALITY,
        pointwise.GROUNDEDNESS,
        pointwise.VERBOSITY,
        pointwise.INSTRUCTION_FOLLOWING,
    ]

    # Metrics passed as plain string identifiers.
    named_metrics = [
        "exact_match",
        "bleu",
        "rouge_1",
        "rouge_2",
        "rouge_l_sum",
    ]

    return template_metrics + named_metrics
| 190 | + |
| 191 | + |
180 | 192 | EXPERIMENT_NAME = f"eval-test-experiment-airflow-operator-{ENV_ID}".replace("_", "-") |
181 | 193 | EXPERIMENT_RUN_NAME = f"eval-experiment-airflow-operator-run-{ENV_ID}".replace("_", "-") |
182 | 194 | PROMPT_TEMPLATE = "{instruction}. Article: {context}. Summary:" |
@@ -281,7 +293,7 @@ def get_actual_models(key): |
281 | 293 | location=REGION, |
282 | 294 | pretrained_model=MULTIMODAL_MODEL, |
283 | 295 | eval_dataset=EVAL_DATASET, |
284 | | - metrics=METRICS, |
| 296 | + metrics=_get_metrics(), |
285 | 297 | experiment_name=EXPERIMENT_NAME, |
286 | 298 | experiment_run_name=EXPERIMENT_RUN_NAME, |
287 | 299 | prompt_template=PROMPT_TEMPLATE, |
|
0 commit comments