@@ -4501,3 +4501,138 @@ def test_create_inference_endpoint_from_catalog(self, mock_get_session: Mock) ->
45014501 )
45024502 assert isinstance (endpoint , InferenceEndpoint )
45034503 assert endpoint .name == "llama-3-2-3b-instruct-eey"
4504+
4505+
@pytest.mark.parametrize(
    "custom_image, expected_image_payload",
    [
        # custom_image omitted entirely.
        pytest.param(
            None,
            {"huggingface": {}},
            id="no_custom_image",
        ),
        # Flat dict: expected to be nested under the "custom" key.
        pytest.param(
            {"url": "my.registry/my-image:latest", "port": 8080},
            {"custom": {"url": "my.registry/my-image:latest", "port": 8080}},
            id="flat_dict_custom_image",
        ),
        # Dict already keyed by "tgi": expected to be sent unchanged.
        pytest.param(
            {"tgi": {"url": "ghcr.io/huggingface/text-generation-inference:latest"}},
            {"tgi": {"url": "ghcr.io/huggingface/text-generation-inference:latest"}},
            id="keyed_tgi_custom_image",
        ),
        # Dict already keyed by "custom": expected to be sent unchanged.
        pytest.param(
            {"custom": {"url": "another.registry/custom:v2"}},
            {"custom": {"url": "another.registry/custom:v2"}},
            id="keyed_custom_custom_image",
        ),
    ],
)
@patch("huggingface_hub.hf_api.get_session")
def test_create_inference_endpoint_custom_image_payload(
    mock_post: Mock,
    custom_image: Optional[dict],
    expected_image_payload: dict,
):
    """Check the ``model.image`` entry of the JSON body posted when creating an endpoint.

    Args:
        mock_post: Mock injected by ``@patch`` in place of
            ``huggingface_hub.hf_api.get_session``; its return value stands in
            for the session whose ``post`` call is inspected.
        custom_image: Value forwarded as ``custom_image``; when ``None`` the
            argument is not passed at all.
        expected_image_payload: Dict that must equal ``payload["model"]["image"]``
            in the recorded ``post`` call.
    """
    # Canned body returned by the fake response's ``.json()``.
    endpoint_json = {
        "compute": {
            "accelerator": "gpu",
            "id": "aws-us-east-1-nvidia-l4-x1",
            "instanceSize": "x1",
            "instanceType": "nvidia-l4",
            "scaling": {
                "maxReplica": 1,
                "measure": {"hardwareUsage": None},
                "metric": "hardwareUsage",
                "minReplica": 0,
                "scaleToZeroTimeout": 15,
            },
        },
        "model": {
            "env": {},
            "framework": "pytorch",
            "image": {
                "tgi": {
                    "disableCustomKernels": False,
                    "healthRoute": "/health",
                    "port": 80,
                    "url": "ghcr.io/huggingface/text-generation-inference:3.1.1",
                }
            },
            "repository": "meta-llama/Llama-3.2-3B-Instruct",
            "revision": "0cb88a4f764b7a12671c53f0838cd831a0843b95",
            "secrets": {},
            "task": "text-generation",
        },
        "name": "llama-3-2-3b-instruct-eey",
        "provider": {"region": "us-east-1", "vendor": "aws"},
        "status": {
            "createdAt": "2025-03-07T15:30:13.949Z",
            "createdBy": {"id": "6273f303f6d63a28483fde12", "name": "Wauplin"},
            "message": "Endpoint waiting to be scheduled",
            "readyReplica": 0,
            "state": "pending",
            "targetReplica": 1,
            "updatedAt": "2025-03-07T15:30:13.949Z",
            "updatedBy": {"id": "6273f303f6d63a28483fde12", "name": "Wauplin"},
        },
        "type": "protected",
    }

    # Wire the patched get_session() -> session.post() -> fake response chain.
    fake_response = Mock()
    fake_response.raise_for_status.return_value = None
    fake_response.json.return_value = endpoint_json
    session_post = mock_post.return_value.post
    session_post.return_value = fake_response

    create_kwargs = {
        "name": "test-endpoint-custom-img",
        "repository": "meta-llama/Llama-2-7b-chat-hf",
        "framework": "pytorch",
        "accelerator": "gpu",
        "instance_size": "medium",
        "instance_type": "nvidia-a10g",
        "region": "us-east-1",
        "vendor": "aws",
        "type": "protected",
        "task": "text-generation",
        "namespace": "Wauplin",
    }
    # Only forward custom_image when the case supplies one, so the "None"
    # case calls the method without the argument.
    if custom_image is not None:
        create_kwargs["custom_image"] = custom_image

    api = HfApi(endpoint=ENDPOINT_STAGING, token=TOKEN)
    api.create_inference_endpoint(**create_kwargs)

    session_post.assert_called_once()
    sent_body = session_post.call_args.kwargs.get("json", {})

    assert "model" in sent_body
    assert "image" in sent_body["model"]
    assert sent_body["model"]["image"] == expected_image_payload
0 commit comments