Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions SentenceTransformers.Qwen3/src/SentenceEncoder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,12 @@ public static async Task<SentenceEncoder> CreateAsync(
SessionOptions sessionOptions = null,
string modelUrl = null,
string downloadToPath = null,
IProgress<float> progress = null,
CancellationToken cancellationToken = default)
{
var path = downloadToPath ?? Path.Combine(Path.GetTempPath(), "SentenceTransformers.Qwen3", "qwen3-model.onnx");
Directory.CreateDirectory(Path.GetDirectoryName(path)!);
await DownloadModelAsync(modelUrl ?? DefaultModelUrl, path, cancellationToken);
await DownloadModelAsync(modelUrl ?? DefaultModelUrl, path, progress, cancellationToken);
return new SentenceEncoder(sessionOptions, path);
}

Expand Down Expand Up @@ -216,7 +217,7 @@ private static string FindInputName(InferenceSession session, string preferred)
/// Downloads the ONNX model from <paramref name="modelUrl"/> to <paramref name="localPath"/>.
/// Only one download runs at a time. On failure, any partial file at <paramref name="localPath"/> is deleted.
/// </summary>
public static async Task DownloadModelAsync(string modelUrl, string localPath, CancellationToken cancellationToken = default)
public static async Task DownloadModelAsync(string modelUrl, string localPath, IProgress<float> progress = null, CancellationToken cancellationToken = default)
{
if (File.Exists(localPath)) return;

Expand All @@ -239,6 +240,7 @@ public static async Task DownloadModelAsync(string modelUrl, string localPath, C
try
{
response.EnsureSuccessStatusCode();
long? totalFileSize = response.Content.Headers.ContentLength;
var supportsRange = response.Headers.AcceptRanges.Contains("bytes");

await using var fileStream = new FileStream(localPath, FileMode.Create, FileAccess.Write, FileShare.None, buffer.Length, true);
Expand All @@ -255,6 +257,10 @@ public static async Task DownloadModelAsync(string modelUrl, string localPath, C
{
await fileStream.WriteAsync(buffer.AsMemory(0, bytesRead), cancellationToken);
totalBytesRead += bytesRead;
if (totalFileSize.HasValue && progress != null)
{
progress.Report((float)totalBytesRead / totalFileSize.Value);
}
}
finished = true;
}
Expand All @@ -279,6 +285,10 @@ public static async Task DownloadModelAsync(string modelUrl, string localPath, C
response.Dispose();
response = await _downloadClient.SendAsync(newRequest, HttpCompletionOption.ResponseHeadersRead, cancellationToken);
response.EnsureSuccessStatusCode();
if (!supportsRange || totalBytesRead == 0)
{
totalFileSize = response.Content.Headers.ContentLength;
}
}
}
}
Expand Down