|
| 1 | + |
1 | 2 | import functools |
2 | | -import logging |
3 | | -import os |
4 | | -import sys |
5 | | -import zipfile |
6 | 3 | from typing import Dict |
7 | | - |
| 4 | +import sys |
8 | 5 | import numpy as np |
9 | | -import requests |
10 | 6 | import xarray as xr |
11 | | - |
12 | 7 | from openeo.udf import inspect |
13 | 8 |
|
| 9 | +sys.path.append("onnx_deps") |
| 10 | +import onnxruntime as ort |
14 | 11 |
|
15 | 12 |
|
16 | | -# TODO move standard code to UDF repo |
17 | 13 |
|
18 | | -# Fixed directories for dependencies and model files |
19 | | -DEPENDENCIES_DIR = "onnx_dependencies" |
20 | | -MODEL_DIR = "model_files" |
21 | | - |
22 | | - |
23 | | -def download_file(url, path): |
24 | | - """ |
25 | | - Downloads a file from the given URL to the specified path. |
26 | | - """ |
27 | | - response = requests.get(url, stream=True) |
28 | | - with open(path, "wb") as file: |
29 | | - file.write(response.content) |
30 | | - |
31 | | - |
32 | | -def extract_zip(zip_path, extract_to): |
33 | | - """ |
34 | | - Extracts a zip file from zip_path to the specified extract_to directory. |
35 | | - """ |
36 | | - with zipfile.ZipFile(zip_path, "r") as zip_ref: |
37 | | - zip_ref.extractall(extract_to) |
38 | | - os.remove(zip_path) # Clean up the zip file after extraction |
39 | | - |
40 | | - |
41 | | -def add_directory_to_sys_path(directory): |
42 | | - """ |
43 | | - Adds a directory to the Python sys.path if it's not already present. |
44 | | - """ |
45 | | - if directory not in sys.path: |
46 | | - sys.path.append(directory) |
47 | | - |
48 | | -@functools.lru_cache(maxsize=5) |
49 | | -def setup_model_and_dependencies(model_url, dependencies_url): |
50 | | - """ |
51 | | - Main function to set up the model and dependencies by downloading, extracting, |
52 | | - and adding necessary directories to sys.path. |
53 | | - """ |
54 | | - |
55 | | - inspect(message="Create directories") |
56 | | - # Ensure base directories exist |
57 | | - os.makedirs(DEPENDENCIES_DIR, exist_ok=True) |
58 | | - os.makedirs(MODEL_DIR, exist_ok=True) |
59 | | - |
60 | | - # Download and extract dependencies if not already present |
61 | | - if not os.listdir(DEPENDENCIES_DIR): |
62 | | - |
63 | | - inspect(message="Extract dependencies") |
64 | | - zip_path = os.path.join(DEPENDENCIES_DIR, "temp.zip") |
65 | | - download_file(dependencies_url, zip_path) |
66 | | - extract_zip(zip_path, DEPENDENCIES_DIR) |
67 | | - |
68 | | - # Add the extracted dependencies directory to sys.path |
69 | | - add_directory_to_sys_path(DEPENDENCIES_DIR) |
70 | | - |
71 | | - # Download and extract model if not already present |
72 | | - if not os.listdir(MODEL_DIR): |
73 | | - |
74 | | - inspect(message="Extract model") |
75 | | - zip_path = os.path.join(MODEL_DIR, "temp.zip") |
76 | | - download_file(model_url, zip_path) |
77 | | - extract_zip(zip_path, MODEL_DIR) |
78 | | - |
79 | | - |
80 | | -setup_model_and_dependencies( |
81 | | - model_url="https://s3.waw3-1.cloudferro.com/swift/v1/project_dependencies/EURAC_pvfarm_rf_1_median_depth_15.zip", |
82 | | - dependencies_url="https://s3.waw3-1.cloudferro.com/swift/v1/project_dependencies/onnx_dependencies_1.16.3.zip", |
83 | | -) |
84 | | - |
85 | | -# Add dependencies to the Python path |
86 | | -import onnxruntime as ort # Import after downloading dependencies |
87 | | - |
88 | | - |
89 | | -@functools.lru_cache(maxsize=5) |
| 14 | +@functools.lru_cache(maxsize=1) |
90 | 15 | def load_onnx_model(model_name: str) -> ort.InferenceSession: |
91 | 16 | """ |
92 | 17 | Loads an ONNX model from the onnx_models folder and returns an ONNX runtime session. |
93 | 18 |
|
94 | 19 | """ |
95 | 20 | # The onnx_models folder contains the content of the model archive provided in the job options |
96 | | - return ort.InferenceSession( |
97 | | - f"{MODEL_DIR}/{model_name}", providers=["CPUExecutionProvider"] |
98 | | - ) |
99 | | - |
| 21 | + return ort.InferenceSession(f"onnx_models/{model_name}") |
100 | 22 |
|
101 | 23 | def preprocess_input( |
102 | 24 | input_xr: xr.DataArray, ort_session: ort.InferenceSession |
|
0 commit comments