Open
Labels
type:bug
Description
Description of the bug:
Calling this a bug may be harsh: the converted model works on CPU, but even there it is less efficient than it could be.
Example model:
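import torch
import litert_torch
from transformers import WavLMForSequenceClassification

# labels_name_list (the classifier's label names) is assumed to be defined beforehand.
model_path = "roblox/voice-safety-classifier-v2"
model = WavLMForSequenceClassification.from_pretrained(
    model_path, num_labels=len(labels_name_list)
)
dummy_input = (torch.randn(1, 80000),)  # a single waveform of 80,000 samples
edge_model = litert_torch.convert(model, dummy_input)
edge_model.export("./out.tflite")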
The conversion creates two GATHER_ND ops early in the model that in fact just perform a RESHAPE from [512] to [1, 512]. That section could be simplified further as well (many unnecessary reshapes are generated), but the biggest issue is the GATHER_ND ops, which prevent GPU use. (Other things in the model also prevent GPU use, but those have other fixes.)
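To illustrate why such a GATHER_ND is redundant, here is a minimal pure-Python sketch. It assumes the generated indices simply enumerate positions 0..511 in order, which is what "just doing RESHAPE" implies; the actual indices emitted by the converter were not inspected here, and `gather_nd_1d` is a hypothetical helper, not a LiteRT API.

```python
def gather_nd_1d(params, indices):
    """Pure-Python GATHER_ND for 1-D params and indices of shape [A, B, 1]."""
    return [[params[idx[0]] for idx in row] for row in indices]

# Stand-in for the [512] tensor feeding the op (values are arbitrary).
params = [float(i) * 0.5 for i in range(512)]

# Assumed indices: shape [1, 512, 1], listing every position in order.
indices = [[[i] for i in range(512)]]

gathered = gather_nd_1d(params, indices)  # result has shape [1, 512]
reshaped = [list(params)]                 # RESHAPE [512] -> [1, 512]

print(gathered == reshaped)
```

Under that assumption the gather moves no data out of order, so a RESHAPE yields an identical tensor without a non-GPU op.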
Actual vs expected behavior:
Actual: GATHER_ND ops are created unnecessarily.
Expected: Where possible, GPU-supported ops are preferred. If a RESHAPE, TRANSPOSE or STRIDED_SLICE would do, the converter should emit those instead.
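One way to picture the expected behavior is a check that recognises when a GATHER_ND with constant indices is a reshape in disguise. The sketch below is only an illustration of that idea, not how the converter works; `is_identity_gather` and its arguments are hypothetical names, and it covers only the case where each index addresses a single element.

```python
from itertools import product

def is_identity_gather(params_shape, index_tuples):
    """Return True if index_tuples visit every element of a tensor with
    shape params_shape exactly once, in row-major order.

    When this holds, the gather only relabels the shape, so a RESHAPE to
    the output shape would produce the same result.
    """
    expected = product(*(range(dim) for dim in params_shape))
    return [tuple(t) for t in index_tuples] == list(expected)

# The case from this issue, assuming in-order indices over a [512] tensor.
in_order = [(i,) for i in range(512)]
print(is_identity_gather((512,), in_order))

# A reversed gather reorders data, so it is not a plain reshape.
print(is_identity_gather((512,), list(reversed(in_order))))
```

A gather that fails this check might still map onto TRANSPOSE or STRIDED_SLICE, which would need separate pattern checks.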
Any other information you'd like to share?
GATHER_ND is not expected to be supported on GPU in the near future (google-ai-edge/LiteRT#5197).