Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 17 additions & 15 deletions .github/workflows/integration.yml
Original file line number Diff line number Diff line change
Expand Up @@ -145,34 +145,36 @@ jobs:
- test: TestAarch64
instance: aarch64
failure-prefix: aarch64
- test: TestHfHandler
# - test: TestHfHandler
# instance: g6
# failure-prefix: lmi
# - test: TestTrtLlmHandler1
# instance: g6
# failure-prefix: trtllm
# - test: TestTrtLlmHandler2
# instance: g6
# failure-prefix: trtllm
- test: TestVllm1
instance: g6
failure-prefix: lmi
- test: TestTrtLlmHandler1
instance: g6
failure-prefix: trtllm
- test: TestTrtLlmHandler2
- test: TestVllmCustomHandlers
instance: g6
failure-prefix: trtllm
- test: TestVllm1
failure-prefix: lmi
- test: TestVllmLora
instance: g6
failure-prefix: lmi
- test: TestVllmCustomHandlers
- test: TestVllmAsyncLora
instance: g6
failure-prefix: lmi
# - test: TestVllmLora
# instance: g6
# failure-prefix: lmi

- test: TestMultiModalVllm
instance: g6
failure-prefix: lmi
# - test: TestTextEmbedding
# instance: g6
# failure-prefix: lmi
- test: TestCorrectnessTrtLlm
instance: g6
failure-prefix: trtllm
# - test: TestCorrectnessTrtLlm
# instance: g6
# failure-prefix: trtllm
- test: TestStatefulModel
instance: g6
failure-prefix: lmi
Expand Down
159 changes: 159 additions & 0 deletions engines/python/setup/djl_python/async_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,12 @@
# the specific language governing permissions and limitations under the License.
import json
import logging
import os
from typing import AsyncGenerator, Callable, Optional, Union

from djl_python.inputs import Input
from djl_python.outputs import Output
from djl_python.input_parser import SAGEMAKER_ADAPTER_IDENTIFIER_HEADER


def create_non_stream_output(data: Union[str, dict],
Expand Down Expand Up @@ -112,3 +115,159 @@ async def handle_streaming_response(
yield output
if last:
return


def _extract_lora_adapter(raw_request, decoded_payload):
    """
    Resolve the lora adapter name for a request.

    The SageMaker adapter-identifier header takes precedence; otherwise the
    "adapter" key is popped from the decoded payload (mutating it). Returns
    None when neither source names an adapter.
    """
    properties = raw_request.get_properties()
    if SAGEMAKER_ADAPTER_IDENTIFIER_HEADER in properties:
        name = raw_request.get_property(SAGEMAKER_ADAPTER_IDENTIFIER_HEADER)
        logging.debug(f"Found adapter in headers: {name}")
        return name
    if "adapter" in decoded_payload:
        name = decoded_payload.pop("adapter")
        logging.debug(f"Found adapter in payload: {name}")
        return name
    return None


async def register_adapter(inputs: Input, service):
    """
    Registers a lora adapter with the model.

    Reads the adapter name, alias and local source path plus the optional
    "preload" (default true) and "pin" (default false) flags from *inputs*,
    optionally loads and pins the adapter on *service*, and records it in
    service.adapter_registry. Returns an Output payload reporting success,
    or a 424 error payload when registration fails.

    Raises:
        MemoryError: when the backend reports LoRA slot exhaustion.
    """
    name = inputs.get_property("name")
    alias = inputs.get_property("alias") or name
    src = inputs.get_property("src")
    preload = True
    if inputs.contains_key("preload"):
        preload = inputs.get_as_string("preload").lower() == "true"
    pin = False
    if inputs.contains_key("pin"):
        pin = inputs.get_as_string("pin").lower() == "true"

    outputs = Output()
    loaded = False
    try:
        if not os.path.exists(src):
            raise ValueError(
                f"Only local LoRA models are supported. {src} is not a valid path"
            )

        # A pinned adapter must be resident, so preload=false + pin=true
        # is rejected up front.
        if not preload and pin:
            raise ValueError("Can not set preload to false and pin to true")

        if preload:
            loaded = await service.add_lora(name, alias, src)

        if pin:
            await service.pin_lora(name, alias)
        service.adapter_registry[name] = inputs
    except Exception as e:
        logging.debug(f"Failed to register adapter: {e}", exc_info=True)
        if loaded:
            # Roll back the successful load so a pin failure leaves no
            # half-registered adapter behind.
            logging.info(
                f"LoRA adapter {alias} was successfully loaded, but failed to pin, unloading ..."
            )
            await service.remove_lora(name, alias)
        if any(msg in str(e)
               for msg in ("No free lora slots",
                           "greater than the number of GPU LoRA slots")):
            raise MemoryError(str(e))
        err = {"data": "", "last": True, "code": 424, "error": str(e)}
        outputs.add(Output.binary_encode(err), key="data")
        return outputs

    logging.info(f"Registered adapter {alias} from {src} successfully")
    result = {"data": f"Adapter {alias} registered"}
    outputs.add(Output.binary_encode(result), key="data")
    return outputs


async def update_adapter(inputs: Input, service):
    """
    Updates a registered lora adapter's "preload"/"pin" settings.

    The adapter must already exist in service.adapter_registry and its
    source path cannot change. A preload transition false->true loads the
    adapter, true->false unloads it; pin can only be turned on (unpinning
    requires deleting and re-registering the adapter). Returns an Output
    payload reporting success, or a 424 error payload on failure.

    Raises:
        ValueError: if the adapter was never registered.
        MemoryError: when the backend reports LoRA slot exhaustion.
    """
    adapter_name = inputs.get_property("name")
    adapter_alias = inputs.get_property("alias") or adapter_name
    adapter_path = inputs.get_property("src")
    adapter_preload = inputs.get_as_string("preload").lower(
    ) == "true" if inputs.contains_key("preload") else True
    adapter_pin = inputs.get_as_string(
        "pin").lower() == "true" if inputs.contains_key("pin") else False

    if adapter_name not in service.adapter_registry:
        raise ValueError(f"Adapter {adapter_alias} not registered.")

    outputs = Output()
    try:
        if not adapter_preload and adapter_pin:
            # Fix: the option is named "preload", not "load"; message now
            # matches the identical validation in register_adapter.
            raise ValueError("Can not set preload to false and pin to true")

        old_adapter = service.adapter_registry[adapter_name]
        old_adapter_path = old_adapter.get_property("src")
        if adapter_path != old_adapter_path:
            raise NotImplementedError(
                "Updating adapter path is not supported.")

        old_adapter_preload = old_adapter.get_as_string("preload").lower(
        ) == "true" if old_adapter.contains_key("preload") else True
        if adapter_preload != old_adapter_preload:
            if adapter_preload:
                await service.add_lora(adapter_name, adapter_alias,
                                       adapter_path)
            else:
                await service.remove_lora(adapter_name, adapter_alias)

        old_adapter_pin = old_adapter.get_as_string("pin").lower(
        ) == "true" if old_adapter.contains_key("pin") else False
        if adapter_pin != old_adapter_pin:
            if adapter_pin:
                await service.pin_lora(adapter_name, adapter_alias)
            else:
                raise ValueError(
                    f"Unpinning adapter is not supported. To unpin adapter '{adapter_alias}', please delete the adapter and re-register it without pinning."
                )
        service.adapter_registry[adapter_name] = inputs
    except Exception as e:
        logging.debug(f"Failed to update adapter: {e}", exc_info=True)
        # Surface LoRA-slot exhaustion distinctly so callers can tell
        # capacity failures apart from ordinary request errors.
        if any(msg in str(e)
               for msg in ("No free lora slots",
                           "greater than the number of GPU LoRA slots")):
            raise MemoryError(str(e))
        err = {"data": "", "last": True, "code": 424, "error": str(e)}
        outputs.add(Output.binary_encode(err), key="data")
        return outputs

    logging.info(f"Updated adapter {adapter_alias} successfully")
    result = {"data": f"Adapter {adapter_alias} updated"}
    outputs.add(Output.binary_encode(result), key="data")
    return outputs


async def unregister_adapter(inputs: Input, service):
    """
    Unregisters a lora adapter from the model.

    Removes the adapter from the backend and drops its entry from
    service.adapter_registry. Returns an Output payload reporting success,
    or a 424 error payload when removal fails.

    Raises:
        ValueError: if the adapter was never registered.
    """
    name = inputs.get_property("name")
    alias = inputs.get_property("alias") or name

    if name not in service.adapter_registry:
        raise ValueError(f"Adapter {alias} not registered.")

    outputs = Output()
    try:
        await service.remove_lora(name, alias)
        del service.adapter_registry[name]
    except Exception as e:
        logging.debug(f"Failed to unregister adapter: {e}", exc_info=True)
        err = dict(data="", last=True, code=424, error=str(e))
        outputs.add(Output.binary_encode(err), key="data")
        return outputs

    logging.info(f"Unregistered adapter {alias} successfully")
    result = {"data": f"Adapter {alias} unregistered"}
    outputs.add(Output.binary_encode(result), key="data")
    return outputs
4 changes: 3 additions & 1 deletion engines/python/setup/djl_python/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -604,7 +604,9 @@ def update_adapter(inputs: Input):
if adapter_pin:
_service.pin_lora(adapter_name, adapter_alias)
else:
raise NotImplementedError(f"Unpin adapter is not supported.")
raise ValueError(
f"Unpinning adapter is not supported. To unpin adapter '{adapter_alias}', please delete the adapter and re-register it without pinning."
)
_service.adapter_registry[adapter_name] = inputs
except Exception as e:
logging.debug(f"Failed to update adapter: {e}", exc_info=True)
Expand Down
6 changes: 4 additions & 2 deletions engines/python/setup/djl_python/input_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
from djl_python.request_io import TextInput, RequestInput
from djl_python.three_p.three_p_utils import parse_3p_request

SAGEMAKER_ADAPTER_IDENTIFIER_HEADER = "X-Amzn-SageMaker-Adapter-Identifier"


def input_formatter(function):
"""
Expand Down Expand Up @@ -237,9 +239,9 @@ def _fetch_adapters_from_input(input_map: dict, input_item: Input,

# check properties, possible from header
adapter_alias = None
if "X-Amzn-SageMaker-Adapter-Identifier" in input_item.get_properties():
if SAGEMAKER_ADAPTER_IDENTIFIER_HEADER in input_item.get_properties():
adapters_per_item = input_item.get_property(
"X-Amzn-SageMaker-Adapter-Identifier")
SAGEMAKER_ADAPTER_IDENTIFIER_HEADER)
adapter_alias = input_item.get_property(
"X-Amzn-SageMaker-Adapter-Alias")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def __init__(
self.stream_output_formatter = stream_output_formatter
self.accumulate_chunks = accumulate_chunks
self.include_prompt = include_prompt
self.lora_request = None


def convert_lmi_schema_to_completion_request(
Expand Down
Loading
Loading