Skip to content

Commit bcc4bc1

Browse files
authored
Merge pull request #1351 from martindevans/cleanup_mtmd
General MTMD Improvements & Cleanup
2 parents c932cc6 + 2eed441 commit bcc4bc1

File tree

9 files changed

+342
-331
lines changed

9 files changed

+342
-331
lines changed

LLama.Unittest/MtmdWeightsTests.cs

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ public MtmdWeightTests()
3030

3131
_mediaMarker = _mtmdParams.MediaMarker ?? throw new InvalidOperationException("MTMD media marker unavailable.");
3232

33-
_mtmdWeights = MtmdWeights.LoadFromFile(Constants.MtmdMmpPath, _llamaWeights, _mtmdParams);
33+
_mtmdWeights = Task.Run(async () => await MtmdWeights.LoadFromFileAsync(Constants.MtmdMmpPath, _llamaWeights, _mtmdParams)).Result;
3434
_context = _llamaWeights.CreateContext(@params);
3535
}
3636

@@ -53,7 +53,10 @@ private SafeMtmdInputChunks TokenizeWithEmbed(Func<SafeMtmdEmbed> loadEmbed)
5353
Assert.True(embed.Nx > 0);
5454
Assert.True(embed.Ny > 0);
5555
Assert.False(embed.IsAudio);
56-
Assert.True(embed.GetDataSpan().Length > 0);
56+
57+
Assert.True(embed.ByteCount > 0);
58+
using var mem = embed.GetData();
59+
Assert.True(mem.Data.Length > 0);
5760

5861
var status = _mtmdWeights.Tokenize(_mediaMarker, addSpecial: true, parseSpecial: true, out var chunks);
5962
Assert.Equal(0, status);
@@ -71,6 +74,16 @@ private void AssertChunksEvaluate(SafeMtmdInputChunks chunks)
7174
Assert.True(nPast > 0);
7275
}
7376

77+
[Fact, Trait("Category", "NoCI")]
78+
public void BasicPropertyChecks()
79+
{
80+
Assert.False(_mtmdWeights.SupportsAudio);
81+
Assert.True(_mtmdWeights.SupportsVision);
82+
Assert.False(_mtmdWeights.UsesMRope);
83+
Assert.True(_mtmdWeights.UsesNonCausalAttention);
84+
Assert.Equal(-1, _mtmdWeights.AudioBitrate);
85+
}
86+
7487
[Fact,Trait("Category", "NoCI")]
7588
public void EmbedImageAsFileName()
7689
{
@@ -125,8 +138,8 @@ public void TokenizeProvidesChunkMetadata()
125138

126139
Assert.True(imageChunks > 0);
127140
Assert.True(totalTokens > 0);
128-
Assert.Equal(totalTokens, _mtmdWeights.CountTokens(chunks));
129-
Assert.Equal(totalPositions, _mtmdWeights.CountPositions(chunks));
141+
Assert.Equal(totalTokens, chunks.CountTokens());
142+
Assert.Equal(totalPositions, chunks.CountPositions());
130143
Assert.True(_mtmdWeights.SupportsVision);
131144
Assert.False(_mtmdWeights.SupportsAudio);
132145

LLama/Batched/Conversation.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ public static MtmdChunkSequence Create(SafeMtmdInputChunks chunks, MtmdWeights c
101101
}
102102
}
103103

104-
var totalPositions = (int)clipModel.CountPositions(chunks);
104+
var totalPositions = (int)chunks.CountPositions();
105105
return new MtmdChunkSequence(chunks, textTokens, totalPositions);
106106
}
107107

LLama/LLamaExecutorBase.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -362,7 +362,7 @@ protected Task PreprocessMtmd(string text, InferStateArgs args, bool addBos, boo
362362
tokens.Add(token);
363363
}
364364

365-
var totalPositions = (int)ClipModel.CountPositions(chunks);
365+
var totalPositions = (int)chunks.CountPositions();
366366
var fillerToken = GetFillerToken(marker);
367367

368368
if (replaceExisting)

LLama/MtmdWeights.cs

Lines changed: 73 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,36 +1,79 @@
1-
21
using System;
32
using System.Threading;
43
using System.Threading.Tasks;
4+
using LLama.Exceptions;
55
using LLama.Native;
66

77
namespace LLama;
88

99
/// <summary>
1010
/// Lightweight wrapper around the MTMD native context and its helpers.
1111
/// </summary>
12-
public sealed class MtmdWeights : IDisposable
12+
public sealed class MtmdWeights
13+
: IDisposable
1314
{
15+
/// <summary>
16+
/// The native handle, which is used in the native APIs
17+
/// </summary>
18+
/// <remarks>Be careful how you use this!</remarks>
1419
public SafeMtmdModelHandle NativeHandle { get; }
1520

1621
private MtmdWeights(SafeMtmdModelHandle handle)
1722
{
1823
NativeHandle = handle ?? throw new ArgumentNullException(nameof(handle));
1924
}
2025

26+
/// <summary>
27+
/// Load weights into memory
28+
/// </summary>
29+
/// <param name="mmProject">Path to the mmproj file</param>
30+
/// <param name="textModel">The text model</param>
31+
/// <param name="mtmdCtxParams">Parameters for MTMD context creation</param>
32+
/// <returns></returns>
2133
public static MtmdWeights LoadFromFile(string mmProject, LLamaWeights textModel, MtmdContextParams mtmdCtxParams)
2234
{
23-
if (mmProject == null) throw new ArgumentNullException(nameof(mmProject));
24-
if (textModel == null) throw new ArgumentNullException(nameof(textModel));
25-
if (mtmdCtxParams == null) throw new ArgumentNullException(nameof(mtmdCtxParams));
26-
27-
var handle = SafeMtmdModelHandle.LoadFromFile(mmProject, textModel, mtmdCtxParams);
28-
return new MtmdWeights(handle);
35+
return new MtmdWeights(SafeMtmdModelHandle.LoadFromFile(
36+
mmProject ?? throw new ArgumentNullException(nameof(mmProject)),
37+
textModel ?? throw new ArgumentNullException(nameof(textModel)),
38+
mtmdCtxParams ?? throw new ArgumentNullException(nameof(mtmdCtxParams))
39+
));
2940
}
3041

31-
public static Task<MtmdWeights> LoadFromFileAsync(string mmProject, LLamaWeights textModel, MtmdContextParams mtmdCtxParams, CancellationToken token = default)
42+
/// <summary>
43+
/// Load weights into memory
44+
/// </summary>
45+
/// <param name="mmProject">Path to the mmproj file</param>
46+
/// <param name="textModel">The text model</param>
47+
/// <param name="mtmdCtxParams">Parameters for MTMD context creation</param>
48+
/// <param name="token"></param>
49+
/// <returns></returns>
50+
public static async Task<MtmdWeights> LoadFromFileAsync(string mmProject, LLamaWeights textModel, MtmdContextParams mtmdCtxParams, CancellationToken token = default)
3251
{
33-
return Task.Run(() => LoadFromFile(mmProject, textModel, mtmdCtxParams), token);
52+
if (mmProject == null)
53+
throw new ArgumentNullException(nameof(mmProject));
54+
if (textModel == null)
55+
throw new ArgumentNullException(nameof(textModel));
56+
if (mtmdCtxParams == null)
57+
throw new ArgumentNullException(nameof(mtmdCtxParams));
58+
59+
var model = await Task.Run(() =>
60+
{
61+
try
62+
{
63+
// Load the model
64+
return LoadFromFile(mmProject, textModel, mtmdCtxParams);
65+
}
66+
catch (LoadWeightsFailedException)
67+
{
68+
// Convert a LoadWeightsFailedException into a cancellation exception if possible.
69+
token.ThrowIfCancellationRequested();
70+
71+
// Ok the weights failed to load for some reason other than cancellation.
72+
throw;
73+
}
74+
}, token);
75+
76+
return model;
3477
}
3578

3679
/// <summary>
@@ -73,15 +116,31 @@ public int EvaluateChunk(IntPtr chunkPtr, SafeLLamaContextHandle llamaContext, r
73116
public int DecodeImageChunk(IntPtr chunkPtr, SafeLLamaContextHandle llamaContext, IntPtr encodedEmbeddings, ref int nPast, int seqId, int nBatch)
74117
=> NativeHandle.DecodeImageChunk(chunkPtr, llamaContext, encodedEmbeddings, ref nPast, seqId, nBatch);
75118

76-
public ulong CountTokens(SafeMtmdInputChunks chunks) => NativeHandle.CountTokens(chunks);
77-
78-
public long CountPositions(SafeMtmdInputChunks chunks) => NativeHandle.CountPositions(chunks);
79-
119+
/// <summary>
120+
/// Indicates whether the model supports vision inputs.
121+
/// </summary>
80122
public bool SupportsVision => NativeHandle.SupportVision();
123+
124+
/// <summary>
125+
/// Indicates whether the model supports audio inputs.
126+
/// </summary>
81127
public bool SupportsAudio => NativeHandle.SupportAudio();
128+
129+
/// <summary>
130+
/// Indicates whether the model decodes using the non-causal path.
131+
/// </summary>
82132
public bool UsesNonCausalAttention => NativeHandle.DecodeUseNonCausal();
133+
134+
/// <summary>
135+
/// Indicates whether the model decodes using multi-scale RoPE.
136+
/// </summary>
83137
public bool UsesMRope => NativeHandle.DecodeUseMRope();
138+
139+
/// <summary>
140+
/// Gets the audio bitrate advertised by the model.
141+
/// </summary>
84142
public int AudioBitrate => NativeHandle.GetAudioBitrate();
85143

144+
/// <inheritdoc />
86145
public void Dispose() => NativeHandle.Dispose();
87146
}

0 commit comments

Comments
 (0)