@@ -6132,16 +6132,14 @@ export class SamModel extends SamPreTrainedModel {
             // Compute the image embeddings if they are missing
             model_inputs = {
                 ...model_inputs,
-                ...(await this.get_image_embeddings(model_inputs)),
-            };
+                ...(await this.get_image_embeddings(model_inputs))
+            }
+        } else {
+            model_inputs = { ...model_inputs };
         }
 
-        if (!model_inputs.input_labels && model_inputs.input_points) {
-            // Set default input labels if they are missing
-            const shape = model_inputs.input_points.dims.slice(0, -1);
-            const numElements = shape.reduce((a, b) => a * b, 1);
-            model_inputs.input_labels = new Tensor('int64', new BigInt64Array(numElements).fill(1n), shape);
-        }
+        // Set default input labels if they are missing
+        model_inputs.input_labels ??= ones(model_inputs.input_points.dims.slice(0, -1));
 
         const decoder_inputs = {
             image_embeddings: model_inputs.image_embeddings,
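The practical effect of this hunk is that callers no longer need to build `input_labels` by hand when prompting with points: forward() now derives a tensor of ones whose shape matches the point grid. A minimal usage sketch (not part of this commit), assuming the public `@huggingface/transformers` API; the checkpoint id, image URL, and point coordinate below are placeholders:

import { SamModel, AutoProcessor, RawImage } from '@huggingface/transformers';

// Example checkpoint and image; substitute your own.
const model = await SamModel.from_pretrained('Xenova/slimsam-77-uniform');
const processor = await AutoProcessor.from_pretrained('Xenova/slimsam-77-uniform');

const image = await RawImage.read('https://example.com/cat.png');
const input_points = [[[400, 250]]]; // one foreground point, in image coordinates

// No `input_labels` here: forward() fills them with ones
// (shape = input_points.dims.slice(0, -1)), marking every point as foreground.
const inputs = await processor(image, { input_points });
const { pred_masks, iou_scores } = await model(inputs);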
@@ -6190,6 +6188,101 @@ export class SamImageSegmentationOutput extends ModelOutput {
 }
 //////////////////////////////////////////////////
 
+//////////////////////////////////////////////////
+export class Sam2ImageSegmentationOutput extends ModelOutput {
+    /**
+     * @param {Object} output The output of the model.
+     * @param {Tensor} output.iou_scores Predicted IoU scores for each mask.
+     * @param {Tensor} output.pred_masks Predicted segmentation masks.
+     * @param {Tensor} output.object_score_logits Logits for the object score, indicating if an object is present.
+     */
+    constructor({ iou_scores, pred_masks, object_score_logits }) {
+        super();
+        this.iou_scores = iou_scores;
+        this.pred_masks = pred_masks;
+        this.object_score_logits = object_score_logits;
+    }
+}
+
+export class EdgeTamPreTrainedModel extends PreTrainedModel { }
+
+/**
+ * EdgeTAM for generating segmentation masks, given an input image
+ * and optional 2D point locations and bounding boxes.
+ */
+export class EdgeTamModel extends EdgeTamPreTrainedModel {
+
+    /**
+     * Compute image embeddings and positional image embeddings, given the pixel values of an image.
+     * @param {Object} model_inputs Object containing the model inputs.
+     * @param {Tensor} model_inputs.pixel_values Pixel values obtained using a `Sam2Processor`.
+     * @returns {Promise<Record<String, Tensor>>} The image embeddings.
+     */
+    async get_image_embeddings({ pixel_values }) {
+        // in:
+        //  - pixel_values: tensor.float32[batch_size,3,1024,1024]
+        //
+        // out:
+        //  - image_embeddings.0: tensor.float32[batch_size,32,256,256]
+        //  - image_embeddings.1: tensor.float32[batch_size,64,128,128]
+        //  - image_embeddings.2: tensor.float32[batch_size,256,64,64]
+        return await encoderForward(this, { pixel_values });
+    }
+
+    async forward(model_inputs) {
+        // @ts-expect-error ts(2339)
+        const { num_feature_levels } = this.config.vision_config;
+        const image_embeddings_name = Array.from({ length: num_feature_levels }, (_, i) => `image_embeddings.${i}`);
+
+        if (image_embeddings_name.some(name => !model_inputs[name])) {
+            // Compute the image embeddings if they are missing
+            model_inputs = {
+                ...model_inputs,
+                ...(await this.get_image_embeddings(model_inputs))
+            }
+        } else {
+            model_inputs = { ...model_inputs };
+        }
+
+        if (model_inputs.input_points) {
+            if (model_inputs.input_boxes && model_inputs.input_boxes.dims[1] !== 1) {
+                throw new Error('When both `input_points` and `input_boxes` are provided, the number of boxes per image must be 1.');
+            }
+            const shape = model_inputs.input_points.dims;
+            model_inputs.input_labels ??= ones(shape.slice(0, -1));
+            model_inputs.input_boxes ??= full([shape[0], 0, 4], 0.0);
+
+        } else if (model_inputs.input_boxes) { // only boxes
+            const shape = model_inputs.input_boxes.dims;
+            model_inputs.input_labels = full([shape[0], shape[1], 0], -1n);
+            model_inputs.input_points = full([shape[0], 1, 0, 2], 0.0);
+
+        } else {
+            throw new Error('At least one of `input_points` or `input_boxes` must be provided.');
+        }
+
+        const prompt_encoder_mask_decoder_session = this.sessions['prompt_encoder_mask_decoder'];
+        const decoder_inputs = pick(model_inputs, prompt_encoder_mask_decoder_session.inputNames);
+
+        // Returns:
+        //  - iou_scores: tensor.float32[batch_size,num_boxes_or_points,3]
+        //  - pred_masks: tensor.float32[batch_size,num_boxes_or_points,3,256,256]
+        //  - object_score_logits: tensor.float32[batch_size,num_boxes_or_points,1]
+        return await sessionRun(prompt_encoder_mask_decoder_session, decoder_inputs);
+    }
+
+    /**
+     * Runs the model with the provided inputs
+     * @param {Object} model_inputs Model inputs
+     * @returns {Promise<Sam2ImageSegmentationOutput>} Object containing segmentation outputs
+     */
+    async _call(model_inputs) {
+        return new Sam2ImageSegmentationOutput(await super._call(model_inputs));
+    }
+}
+//////////////////////////////////////////////////
+
+
 //////////////////////////////////////////////////
 // MarianMT models
 export class MarianPreTrainedModel extends PreTrainedModel { }
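The new EdgeTamModel mirrors the SamModel flow: an image-encoder run via encoderForward (skipped when the `image_embeddings.{i}` tensors are passed back in) followed by a combined `prompt_encoder_mask_decoder` ONNX session. A rough usage sketch follows; it is not from the commit, the checkpoint id and image URL are placeholders, and the processor is assumed to expose the same point-prompt interface as the SAM processor, so treat it as illustrative only:

import { AutoProcessor, AutoModelForMaskGeneration, RawImage } from '@huggingface/transformers';

// Placeholder model id; an ONNX EdgeTAM export with a
// `prompt_encoder_mask_decoder` session is assumed here.
const model_id = 'onnx-community/edgetam-placeholder';
const processor = await AutoProcessor.from_pretrained(model_id);
const model = await AutoModelForMaskGeneration.from_pretrained(model_id);

const image = await RawImage.read('https://example.com/street.jpg');

// Point prompt: `input_labels` defaults to ones, and an empty (0-box)
// `input_boxes` tensor is created so the decoder session still receives that input.
const inputs = await processor(image, { input_points: [[[256, 256]]] });
const { pred_masks, iou_scores, object_score_logits } = await model(inputs);
// pred_masks: [batch, num_prompts, 3, 256, 256]; pick the candidate with the best IoU score.

// Embeddings can also be computed once and reused across many prompts:
// const embeddings = await model.get_image_embeddings(inputs);
// const outputs2 = await model({ ...inputs, ...embeddings, input_points: other_points });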
@@ -8384,7 +8477,10 @@ const MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING_NAMES = new Map([
     ['maskformer', ['MaskFormerForInstanceSegmentation', MaskFormerForInstanceSegmentation]],
 ]);
 
-const MODEL_FOR_MASK_GENERATION_MAPPING_NAMES = new Map([['sam', ['SamModel', SamModel]]]);
+const MODEL_FOR_MASK_GENERATION_MAPPING_NAMES = new Map([
+    ['sam', ['SamModel', SamModel]],
+    ['edgetam', ['EdgeTamModel', EdgeTamModel]],
+]);
 
 const MODEL_FOR_CTC_MAPPING_NAMES = new Map([
     ['wav2vec2', ['Wav2Vec2ForCTC', Wav2Vec2ForCTC]],
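Registering the 'edgetam' entry is what lets the auto classes resolve the new model from a checkpoint's config. A minimal sketch of the lookup, assuming the mapping is keyed by `config.model_type` as with the other mappings in this file:

// Hypothetical illustration of how the mapping is consumed:
// the key is the model_type, the value pairs the class name with the class itself.
const [className, modelClass] = MODEL_FOR_MASK_GENERATION_MAPPING_NAMES.get('edgetam');
console.log(className);                   // 'EdgeTamModel'
console.log(modelClass === EdgeTamModel); // true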