@@ -11,6 +11,7 @@ import (
 	"github.com/NexusGPU/tensor-fusion/internal/constants"
 	"github.com/NexusGPU/tensor-fusion/internal/gpuallocator"
 	"github.com/NexusGPU/tensor-fusion/internal/gpuallocator/filter"
+	"github.com/NexusGPU/tensor-fusion/internal/utils"
 	"github.com/samber/lo/mutable"
 	corev1 "k8s.io/api/core/v1"
 	errors "k8s.io/apimachinery/pkg/api/errors"
@@ -277,12 +278,15 @@ func (e *NodeExpander) simulateSchedulingWithoutGPU(ctx context.Context, pod *co
 		return nil, fmt.Errorf("pod labels is nil, pod: %s", pod.Name)
 	}
 
-	// Disable the tensor fusion label to simulate scheduling without GPU plugins
+	// Disable the tensor fusion component label to simulate scheduling without GPU plugins
 	// NOTE: must apply patch after `go mod vendor`, FindNodesThatFitPod is not exported from Kubernetes
 	// Run `git apply ./patches/scheduler-sched-one.patch` once or `bash scripts/patch-scheduler.sh`
-	pod.Labels[constants.TensorFusionEnabledLabelKey] = constants.FalseStringValue
+	if !utils.IsTensorFusionPod(pod) {
+		return nil, fmt.Errorf("pod to check expansion is not a tensor fusion worker pod: %s", pod.Name)
+	}
+	delete(pod.Labels, constants.LabelComponent)
 	scheduleResult, _, err := e.scheduler.FindNodesThatFitPod(ctx, fwkInstance, state, pod)
-	pod.Labels[constants.TensorFusionEnabledLabelKey] = constants.TrueStringValue
+	pod.Labels[constants.LabelComponent] = constants.ComponentWorker
 	if len(scheduleResult) == 0 {
 		return nil, err
 	}
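
The replaced toggle above means the simulation no longer flips a true/false enabled label; it removes the component label entirely so the pod momentarily stops looking like a tensor fusion worker, then puts the worker value back. A minimal sketch of that pattern, reusing the file's existing imports; `simulateWithoutGPULabel` and the `simulate` callback are hypothetical stand-ins for the patched `FindNodesThatFitPod` call:

```go
// Hypothetical helper mirroring the label toggle above: strip the component
// label, run the scheduling simulation, then restore the worker label so the
// pod object handed back to the scheduler framework is unchanged.
func simulateWithoutGPULabel(pod *corev1.Pod, simulate func(*corev1.Pod) error) error {
	delete(pod.Labels, constants.LabelComponent)
	defer func() {
		pod.Labels[constants.LabelComponent] = constants.ComponentWorker
	}()
	return simulate(pod)
}
```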
@@ -382,32 +386,34 @@ func (e *NodeExpander) checkGPUFitForNewNode(pod *corev1.Pod, gpus []*tfv1.GPU)
 
 func (e *NodeExpander) createGPUNodeClaim(ctx context.Context, pod *corev1.Pod, preparedNode *corev1.Node) error {
 	owners := preparedNode.GetOwnerReferences()
+	isKarpenterNodeClaim := false
+	isGPUNodeClaim := false
 	controlledBy := &metav1.OwnerReference{}
 	for _, owner := range owners {
-		if owner.Controller != nil && *owner.Controller {
-			controlledBy = &owner
+		controlledBy = &owner
+		// Karpenter owner reference is not controller reference
+		if owner.Kind == constants.KarpenterNodeClaimKind {
+			isKarpenterNodeClaim = true
+			break
+		} else if owner.Kind == tfv1.GPUNodeClaimKind {
+			isGPUNodeClaim = true
 			break
 		}
 	}
-	if controlledBy.Kind == "" {
-		e.logger.Info("node is not owned by any provisioner, skip expansion", "node", preparedNode.Name)
+	if !isKarpenterNodeClaim && !isGPUNodeClaim {
+		e.logger.Info("node is not owned by any known provisioner, skip expansion", "node", preparedNode.Name)
 		return nil
 	}
 	e.logger.Info("start expanding node from existing template node", "tmplNode", preparedNode.Name)
-
-	switch controlledBy.Kind {
-	case constants.KarpenterNodeClaimKind:
+	if isKarpenterNodeClaim {
 		// Check if controllerMeta's parent is GPUNodeClaim using unstructured object
 		return e.handleKarpenterNodeClaim(ctx, pod, preparedNode, controlledBy)
-	case tfv1.GPUNodeClaimKind:
+	} else if isGPUNodeClaim {
 		// Running in Provisioning mode, clone the parent GPUNodeClaim and apply
 		e.logger.Info("node is controlled by GPUNodeClaim, cloning another to expand node", "tmplNode", preparedNode.Name)
 		return e.cloneGPUNodeClaim(ctx, pod, preparedNode, controlledBy)
-	default:
-		e.logger.Info("node is not controlled by any known provisioner, skip expansion", "tmplNode", preparedNode.Name,
-			"controller", controlledBy.Kind)
-		return nil
 	}
+	return nil
 }
 
 // handleKarpenterNodeClaim handles the case where the controller is a Karpenter NodeClaim
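
The rewritten owner scan matters because Karpenter adds its NodeClaim owner reference without marking it as the controller, so the old controller-only check never matched it. A rough standalone sketch of the Kind-based classification; the literal Kind strings are assumptions, since the values of `constants.KarpenterNodeClaimKind` and `tfv1.GPUNodeClaimKind` are not shown in this diff:

```go
// classifyNodeOwner is an illustrative stand-in for the loop above: it picks
// the first owner reference whose Kind identifies a known provisioner,
// regardless of whether that reference carries the controller flag.
func classifyNodeOwner(owners []metav1.OwnerReference) (provisioner string, ref *metav1.OwnerReference) {
	for i := range owners {
		switch owners[i].Kind {
		case "NodeClaim": // assumed value of constants.KarpenterNodeClaimKind
			return "karpenter", &owners[i]
		case "GPUNodeClaim": // assumed value of tfv1.GPUNodeClaimKind
			return "gpunodeclaim", &owners[i]
		}
	}
	return "", nil // not owned by any known provisioner; caller skips expansion
}
```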
@@ -424,8 +430,12 @@ func (e *NodeExpander) handleKarpenterNodeClaim(ctx context.Context, pod *corev1
 	// Check if the NodeClaim has owner references
 	ownerRefs := nodeClaim.GetOwnerReferences()
 	var nodeClaimParent *metav1.OwnerReference
+	hasNodePoolParent := false
 
 	for _, owner := range ownerRefs {
+		if owner.Kind == constants.KarpenterNodePoolKind {
+			hasNodePoolParent = true
+		}
 		if owner.Controller != nil && *owner.Controller {
 			nodeClaimParent = &owner
 			break
@@ -437,13 +447,13 @@ func (e *NodeExpander) handleKarpenterNodeClaim(ctx context.Context, pod *corev1
 		e.logger.Info("NodeClaim parent is GPUNodeClaim, cloning another to expand node",
 			"nodeClaimName", controlledBy.Name, "gpuNodeClaimParent", nodeClaimParent.Name)
 		return e.cloneGPUNodeClaim(ctx, pod, preparedNode, nodeClaimParent)
-	} else if nodeClaimParent != nil {
-		// No GPUNodeClaim parent, create karpenter NodeClaim directly with special label identifier
+	} else if hasNodePoolParent {
+		// owned by Karpenter node pool, create NodeClaim directly with special label identifier
 		e.logger.Info("NodeClaim owned by Karpenter Pool, creating Karpenter NodeClaim to expand node",
 			"nodeClaimName", controlledBy.Name)
 		return e.createKarpenterNodeClaimDirect(ctx, pod, preparedNode, nodeClaim)
 	} else {
-		return fmt.Errorf("NodeClaim has no parent, can not expand node, should not happen")
+		return fmt.Errorf("NodeClaim has no valid parent, can not expand node, should not happen")
 	}
 }
 
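
In `handleKarpenterNodeClaim`, the new `hasNodePoolParent` flag separates "owned by a Karpenter NodePool" from "has a controller reference", which drives the three-way branch above. A condensed sketch of that decision; the "NodePool" Kind value is an assumption for `constants.KarpenterNodePoolKind`, and the returned action names stand in for the real helper calls:

```go
// resolveNodeClaimExpansion summarizes the branching above: clone the parent
// GPUNodeClaim when the NodeClaim is controlled by one, create a Karpenter
// NodeClaim directly when a NodePool owns it, otherwise report an error.
func resolveNodeClaimExpansion(owners []metav1.OwnerReference) (string, error) {
	var controller *metav1.OwnerReference
	hasNodePoolParent := false
	for i := range owners {
		if owners[i].Kind == "NodePool" { // assumed value of constants.KarpenterNodePoolKind
			hasNodePoolParent = true
		}
		if owners[i].Controller != nil && *owners[i].Controller {
			controller = &owners[i]
			break
		}
	}
	switch {
	case controller != nil && controller.Kind == "GPUNodeClaim":
		return "cloneGPUNodeClaim", nil
	case hasNodePoolParent:
		return "createKarpenterNodeClaimDirect", nil
	default:
		return "", fmt.Errorf("NodeClaim has no valid parent, can not expand node")
	}
}
```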