"udf": "\nimport functools\nfrom typing import Dict\nimport sys\nimport numpy as np\nimport xarray as xr\nfrom openeo.udf import inspect\n\nsys.path.append(\"onnx_deps\") \nsys.path.append(\"onnx_models\") \nimport onnxruntime as ort\n\n\n\n@functools.lru_cache(maxsize=1)\ndef load_onnx_model(model_name: str) -> ort.InferenceSession:\n \"\"\"\n Loads an ONNX model from the onnx_models folder and returns an ONNX runtime session.\n\n \"\"\"\n # The onnx_models folder contains the content of the model archive provided in the job options\n return ort.InferenceSession(f\"onnx_models/{model_name}\")\n\ndef preprocess_input(\n input_xr: xr.DataArray, ort_session: ort.InferenceSession\n) -> tuple:\n \"\"\"\n Preprocess the input DataArray by ensuring the dimensions are in the correct order,\n reshaping it, and returning the reshaped numpy array and the original shape.\n \"\"\"\n input_xr = input_xr.transpose(\"y\", \"x\", \"bands\")\n input_shape = input_xr.shape\n input_np = input_xr.values.reshape(-1, ort_session.get_inputs()[0].shape[1])\n input_np = input_np.astype(np.float32)\n return input_np, input_shape\n\n\ndef run_inference(input_np: np.ndarray, ort_session: ort.InferenceSession) -> tuple:\n \"\"\"\n Run inference using the ONNX runtime session and return predicted labels and probabilities.\n \"\"\"\n ort_inputs = {ort_session.get_inputs()[0].name: input_np}\n ort_outputs = ort_session.run(None, ort_inputs)\n predicted_labels = ort_outputs[0]\n return predicted_labels\n\n\ndef postprocess_output(predicted_labels: np.ndarray, input_shape: tuple) -> tuple:\n \"\"\"\n Postprocess the output by reshaping the predicted labels and probabilities into the original spatial structure.\n \"\"\"\n predicted_labels = predicted_labels.reshape(input_shape[0], input_shape[1])\n\n return predicted_labels\n\n\ndef create_output_xarray(\n predicted_labels: np.ndarray, input_xr: xr.DataArray\n) -> xr.DataArray:\n \"\"\"\n Create an xarray DataArray with predicted labels and probabilities stacked along the bands dimension.\n \"\"\"\n\n return xr.DataArray(\n predicted_labels,\n dims=[\"y\", \"x\"],\n coords={\"y\": input_xr.coords[\"y\"], \"x\": input_xr.coords[\"x\"]},\n )\n\n\ndef apply_model(input_xr: xr.DataArray) -> xr.DataArray:\n \"\"\"\n Run inference on the given input data using the provided ONNX runtime session.\n This method is called for each timestep in the chunk received by apply_datacube.\n \"\"\"\n\n # Step 1: Load the ONNX model\n inspect(message=\"load onnx model\")\n ort_session = load_onnx_model(\"rf_1_median_depth_15.onnx\")\n\n # Step 2: Preprocess the input\n inspect(message=\"preprocess input\")\n input_np, input_shape = preprocess_input(input_xr, ort_session)\n\n # Step 3: Perform inference\n inspect(message=\"run model inference\")\n predicted_labels = run_inference(input_np, ort_session)\n\n # Step 4: Postprocess the output\n inspect(message=\"post process output\")\n predicted_labels = postprocess_output(predicted_labels, input_shape)\n\n # Step 5: Create the output xarray\n inspect(message=\"create output xarray\")\n return create_output_xarray(predicted_labels, input_xr)\n\n\ndef apply_datacube(cube: xr.DataArray, context: Dict) -> xr.DataArray:\n \"\"\"\n Function that is called for each chunk of data that is processed.\n The function name and arguments are defined by the UDF API.\n \"\"\"\n # Define how you want to handle nan values\n cube = cube.fillna(-999999)\n\n # Apply the model for each timestep in the chunk\n output_data = apply_model(cube)\n\n return output_data\n"
0 commit comments