|
12 | 12 | # the specific language governing permissions and limitations under the License. |
13 | 13 | import json |
14 | 14 | import logging |
| 15 | +import os |
15 | 16 | from typing import AsyncGenerator, Callable, Optional, Union |
16 | 17 |
|
| 18 | +from djl_python.inputs import Input |
17 | 19 | from djl_python.outputs import Output |
| 20 | +from djl_python.input_parser import SAGEMAKER_ADAPTER_IDENTIFIER_HEADER |
18 | 21 |
|
19 | 22 |
|
20 | 23 | def create_non_stream_output(data: Union[str, dict], |
@@ -112,3 +115,159 @@ async def handle_streaming_response( |
112 | 115 | yield output |
113 | 116 | if last: |
114 | 117 | return |
| 118 | + |
| 119 | + |
def _extract_lora_adapter(raw_request, decoded_payload):
    """
    Resolve the lora adapter name for a request.

    The SageMaker adapter identifier header wins when it is present among the
    request properties. Otherwise the "adapter" entry is popped from the
    decoded payload, so the payload no longer carries that key afterwards.

    Returns the adapter name, or None when neither source provides one.
    """
    # Header takes precedence; the payload is left untouched in that case.
    if SAGEMAKER_ADAPTER_IDENTIFIER_HEADER in raw_request.get_properties():
        header_adapter = raw_request.get_property(
            SAGEMAKER_ADAPTER_IDENTIFIER_HEADER)
        logging.debug(f"Found adapter in headers: {header_adapter}")
        return header_adapter

    if "adapter" not in decoded_payload:
        return None

    # pop() removes the key so it is not forwarded with the rest of the payload.
    payload_adapter = decoded_payload.pop("adapter")
    logging.debug(f"Found adapter in payload: {payload_adapter}")
    return payload_adapter
| 135 | + |
| 136 | + |
async def register_adapter(inputs: Input, service):
    """
    Registers lora adapter with the model.

    Reads the "name", "alias" (falls back to the name) and "src" properties
    from ``inputs``, plus the optional "preload" (default true) and "pin"
    (default false) flags. When preload is enabled the adapter is loaded via
    ``service.add_lora``; when pin is enabled it is pinned via
    ``service.pin_lora``. On success ``inputs`` is stored in
    ``service.adapter_registry`` under the adapter name.

    Returns:
        An Output whose "data" entry holds either the success message or an
        error payload with code 424.

    Raises:
        MemoryError: when the failure message indicates that no LoRA slots
            are available.
    """
    adapter_name = inputs.get_property("name")
    adapter_alias = inputs.get_property("alias") or adapter_name
    adapter_path = inputs.get_property("src")
    adapter_preload = (inputs.get_as_string("preload").lower() == "true"
                       if inputs.contains_key("preload") else True)
    adapter_pin = (inputs.get_as_string("pin").lower() == "true"
                   if inputs.contains_key("pin") else False)

    outputs = Output()
    loaded = False
    try:
        # "src" may be absent (presumably get_property then yields None --
        # confirm). os.path.exists(None) raises TypeError, which would surface
        # as an opaque stat error instead of the intended message, so an empty
        # path is rejected explicitly here.
        if not adapter_path or not os.path.exists(adapter_path):
            raise ValueError(
                f"Only local LoRA models are supported. {adapter_path} is not a valid path"
            )

        if not adapter_preload and adapter_pin:
            raise ValueError("Can not set preload to false and pin to true")

        if adapter_preload:
            loaded = await service.add_lora(adapter_name, adapter_alias,
                                            adapter_path)

        if adapter_pin:
            await service.pin_lora(adapter_name, adapter_alias)
        service.adapter_registry[adapter_name] = inputs
    except Exception as e:
        logging.debug(f"Failed to register adapter: {e}", exc_info=True)
        # Roll back a load that succeeded before a later step failed, so the
        # service does not keep an adapter that never made it into the registry.
        if loaded:
            logging.info(
                f"LoRA adapter {adapter_alias} was successfully loaded, but failed to pin, unloading ..."
            )
            await service.remove_lora(adapter_name, adapter_alias)
        # Slot exhaustion is escalated as MemoryError instead of a 424 payload.
        if any(msg in str(e)
               for msg in ("No free lora slots",
                           "greater than the number of GPU LoRA slots")):
            raise MemoryError(str(e)) from e
        err = {"data": "", "last": True, "code": 424, "error": str(e)}
        outputs.add(Output.binary_encode(err), key="data")
        return outputs

    logging.info(
        f"Registered adapter {adapter_alias} from {adapter_path} successfully")
    result = {"data": f"Adapter {adapter_alias} registered"}
    outputs.add(Output.binary_encode(result), key="data")
    return outputs
| 187 | + |
| 188 | + |
def _adapter_flag(source, key, default):
    """Return the boolean value of property ``key`` on ``source``, or ``default`` when the key is absent."""
    if not source.contains_key(key):
        return default
    return source.get_as_string(key).lower() == "true"


async def update_adapter(inputs: Input, service):
    """
    Updates lora adapter with the model.

    Compares the requested "preload" and "pin" flags against the ones stored
    in ``service.adapter_registry`` for this adapter and applies the
    difference: the adapter is loaded or unloaded when preload changes, and
    pinned when pin switches on. Changing "src" and unpinning are rejected.
    On success the registry entry is replaced with ``inputs``.

    Returns:
        An Output whose "data" entry holds either the success message or an
        error payload with code 424.

    Raises:
        ValueError: when the adapter name is not in the registry.
        MemoryError: when the failure message indicates that no LoRA slots
            are available.
    """
    adapter_name = inputs.get_property("name")
    adapter_alias = inputs.get_property("alias") or adapter_name
    adapter_path = inputs.get_property("src")
    adapter_preload = _adapter_flag(inputs, "preload", True)
    adapter_pin = _adapter_flag(inputs, "pin", False)

    if adapter_name not in service.adapter_registry:
        raise ValueError(f"Adapter {adapter_alias} not registered.")

    outputs = Output()
    try:
        if not adapter_preload and adapter_pin:
            raise ValueError("Can not set preload to false and pin to true")

        old_adapter = service.adapter_registry[adapter_name]
        old_adapter_path = old_adapter.get_property("src")
        if adapter_path != old_adapter_path:
            raise NotImplementedError(
                "Updating adapter path is not supported.")

        old_adapter_preload = _adapter_flag(old_adapter, "preload", True)
        old_adapter_pin = _adapter_flag(old_adapter, "pin", False)

        # Reject the unsupported unpin request before anything is loaded or
        # unloaded; otherwise a rejected update could still change the
        # service state while the registry keeps the old settings.
        if old_adapter_pin and not adapter_pin:
            raise ValueError(
                f"Unpinning adapter is not supported. To unpin adapter '{adapter_alias}', please delete the adapter and re-register it without pinning."
            )

        if adapter_preload != old_adapter_preload:
            if adapter_preload:
                await service.add_lora(adapter_name, adapter_alias,
                                       adapter_path)
            else:
                await service.remove_lora(adapter_name, adapter_alias)

        if adapter_pin and not old_adapter_pin:
            await service.pin_lora(adapter_name, adapter_alias)
        service.adapter_registry[adapter_name] = inputs
    except Exception as e:
        logging.debug(f"Failed to update adapter: {e}", exc_info=True)
        # Slot exhaustion is escalated as MemoryError instead of a 424 payload.
        if any(msg in str(e)
               for msg in ("No free lora slots",
                           "greater than the number of GPU LoRA slots")):
            raise MemoryError(str(e)) from e
        err = {"data": "", "last": True, "code": 424, "error": str(e)}
        outputs.add(Output.binary_encode(err), key="data")
        return outputs

    logging.info(f"Updated adapter {adapter_alias} successfully")
    result = {"data": f"Adapter {adapter_alias} updated"}
    outputs.add(Output.binary_encode(result), key="data")
    return outputs
| 248 | + |
| 249 | + |
async def unregister_adapter(inputs: Input, service):
    """
    Unregisters lora adapter from the model.

    Looks up the adapter by its "name" property (the "alias" property, or the
    name when no alias is given, is used in messages), removes it through
    ``service.remove_lora`` and deletes its entry from
    ``service.adapter_registry``.

    Returns:
        An Output whose "data" entry holds either the success message or an
        error payload with code 424 when removal fails.

    Raises:
        ValueError: when the adapter name is not in the registry.
    """
    name = inputs.get_property("name")
    alias = inputs.get_property("alias") or name

    if name not in service.adapter_registry:
        raise ValueError(f"Adapter {alias} not registered.")

    response = Output()
    try:
        await service.remove_lora(name, alias)
        del service.adapter_registry[name]
    except Exception as exc:
        # Removal failures are reported to the caller in the payload, not raised.
        logging.debug(f"Failed to unregister adapter: {exc}", exc_info=True)
        failure = {"data": "", "last": True, "code": 424, "error": str(exc)}
        response.add(Output.binary_encode(failure), key="data")
    else:
        logging.info(f"Unregistered adapter {alias} successfully")
        success = {"data": f"Adapter {alias} unregistered"}
        response.add(Output.binary_encode(success), key="data")
    return response
0 commit comments