Skip to content

Commit 1b054ca

Browse files
committed
feat: implement automatic token refresh for API calls in ProjectContext and related services
- Fixes #114 - Added `ExecuteWithTokenRefreshAsync` methods to handle API calls with automatic token refresh on authentication failures. - Introduced `RefreshConnectionAsync` to re-acquire the MSAL token when needed. - Updated various services and view models to utilize the new token refresh mechanism, ensuring seamless API interactions.
1 parent 3f69534 commit 1b054ca

File tree

7 files changed

+171
-68
lines changed

7 files changed

+171
-68
lines changed

Source/TeamMate/Model/ProjectContext.cs

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,12 @@
66
using Microsoft.TeamFoundation.SourceControl.WebApi;
77
using Microsoft.TeamFoundation.WorkItemTracking.WebApi;
88
using Microsoft.TeamFoundation.WorkItemTracking.WebApi.Models;
9+
using Microsoft.VisualStudio.Services.Common;
910
using Microsoft.VisualStudio.Services.WebApi;
11+
using System;
1012
using System.Collections.Generic;
1113
using System.Collections.ObjectModel;
14+
using System.Threading;
1215
using System.Threading.Tasks;
1316
using Microsoft.VisualStudio.Services.Graph.Client;
1417
using System.Runtime.Versioning;
@@ -35,6 +38,75 @@ public ProjectContext(ProjectReference reference)
3538

3639
public VssConnection Connection { get; set; }
3740

41+
/// <summary>
42+
/// Function to refresh the connection when the token expires.
43+
/// </summary>
44+
public Func<CancellationToken, Task> RefreshConnectionAsync { get; set; }
45+
46+
private readonly SemaphoreSlim _refreshLock = new SemaphoreSlim(1, 1);
47+
48+
/// <summary>
49+
/// Executes an API call with automatic token refresh on authentication failure.
50+
/// </summary>
51+
public async Task<T> ExecuteWithTokenRefreshAsync<T>(Func<Task<T>> apiCall, CancellationToken cancellationToken = default)
52+
{
53+
try
54+
{
55+
return await apiCall();
56+
}
57+
catch (VssUnauthorizedException)
58+
{
59+
// Token expired, refresh and retry
60+
await RefreshConnectionIfNeededAsync(cancellationToken);
61+
return await apiCall();
62+
}
63+
catch (VssServiceException ex) when (ex.Message.Contains("VS30063") || ex.Message.Contains("not authorized"))
64+
{
65+
// Authorization error, refresh and retry
66+
await RefreshConnectionIfNeededAsync(cancellationToken);
67+
return await apiCall();
68+
}
69+
}
70+
71+
/// <summary>
72+
/// Executes an API call with automatic token refresh on authentication failure (for void returns).
73+
/// </summary>
74+
public async Task ExecuteWithTokenRefreshAsync(Func<Task> apiCall, CancellationToken cancellationToken = default)
75+
{
76+
try
77+
{
78+
await apiCall();
79+
}
80+
catch (VssUnauthorizedException)
81+
{
82+
// Token expired, refresh and retry
83+
await RefreshConnectionIfNeededAsync(cancellationToken);
84+
await apiCall();
85+
}
86+
catch (VssServiceException ex) when (ex.Message.Contains("VS30063") || ex.Message.Contains("not authorized"))
87+
{
88+
// Authorization error, refresh and retry
89+
await RefreshConnectionIfNeededAsync(cancellationToken);
90+
await apiCall();
91+
}
92+
}
93+
94+
private async Task RefreshConnectionIfNeededAsync(CancellationToken cancellationToken)
95+
{
96+
await _refreshLock.WaitAsync(cancellationToken);
97+
try
98+
{
99+
if (RefreshConnectionAsync != null)
100+
{
101+
await RefreshConnectionAsync(cancellationToken);
102+
}
103+
}
104+
finally
105+
{
106+
_refreshLock.Release();
107+
}
108+
}
109+
38110
public WorkItemTrackingHttpClient WorkItemTrackingClient { get; set; }
39111

40112
public GitHttpClient GitHttpClient { get; set; }

Source/TeamMate/Services/SearchService.cs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,11 @@ public async Task<SearchResults> AdoSearch(SearchExpression searchExpression, Ca
8282

8383
await ChaosMonkey.ChaosAsync(ChaosScenarios.VstsSearch);
8484

85-
var result = await pc.WorkItemTrackingClient.QueryAsync(query);
85+
var result = await pc.ExecuteWithTokenRefreshAsync(async () =>
86+
{
87+
var client = pc.Connection.GetClient<Microsoft.TeamFoundation.WorkItemTracking.WebApi.WorkItemTrackingHttpClient>();
88+
return await client.QueryAsync(query);
89+
});
8690
var workItems = result.WorkItems.Select(wi => CreateWorkItemViewModel(wi));
8791
var searchResults = workItems.Select(wi => new SearchResult(wi, SearchResultSource.Ado)).ToArray();
8892
return new SearchResults(searchResults, result.QueryResult.WorkItems.Count());

Source/TeamMate/Services/VstsConnectionService.cs

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,24 @@ private static void RegisterTokenCache(ITokenCache tokenCache, string cacheFileP
158158
});
159159
}
160160

161+
/// <summary>
162+
/// Refreshes the connection for an existing ProjectContext by re-acquiring the MSAL token.
163+
/// </summary>
164+
private async Task RefreshConnectionAsync(ProjectContext projectContext, CancellationToken cancellationToken)
165+
{
166+
using (Log.PerformanceBlock("Refreshing connection token for {0}", projectContext.ProjectInfo?.ProjectCollectionUri))
167+
{
168+
var connection = await this.CreateConnectionAsync(projectContext.ProjectInfo.ProjectCollectionUri, cancellationToken);
169+
170+
// Update the connection and all the clients
171+
projectContext.Connection = connection;
172+
projectContext.WorkItemTrackingClient = connection.GetClient<WorkItemTrackingHttpClient>();
173+
projectContext.WorkItemTrackingBatchClient = connection.GetClient<WorkItemTrackingBatchHttpClient>();
174+
projectContext.GitHttpClient = connection.GetClient<GitHttpClient>();
175+
projectContext.GraphClient = connection.GetClient<GraphHttpClient>();
176+
}
177+
}
178+
161179
public async Task<ProjectReference> ResolveProjectReferenceAsync(Uri projectCollectionUri, string projectName)
162180
{
163181
var connection = await this.CreateConnectionAsync(projectCollectionUri);
@@ -284,6 +302,9 @@ private async Task<ProjectContext> DoConnectAsync(ProjectInfo projectInfo, Cance
284302
projectContext.WorkItemFieldsByName = fields.ToDictionary(f => f.ReferenceName, StringComparer.OrdinalIgnoreCase);
285303
projectContext.RequiredWorkItemFieldNames = GetWorkItemFieldsToPrefetch(projectContext.WorkItemFieldsByName);
286304

305+
// Set up the refresh connection delegate
306+
projectContext.RefreshConnectionAsync = async (ct) => await RefreshConnectionAsync(projectContext, ct);
307+
287308
return projectContext;
288309
}
289310
catch (Exception e)

Source/TeamMate/ViewModels/PullRequestQueryViewModel.cs

Lines changed: 46 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -97,53 +97,55 @@ private async Task DoRefreshAsync(NotificationScope notificationScope)
9797
{
9898
await ChaosMonkey.ChaosAsync(ChaosScenarios.PullRequestQueryExecution);
9999

100-
List<Task> tasks = new List<Task>();
101-
102-
var queryAsyncTask = projectContext.GitHttpClient.GetPullRequestsByProjectAsync(
103-
query.ProjectName,
104-
query.GitPullRequestSearchCriteria);
105-
106-
tasks.Add(queryAsyncTask);
107-
108-
await Task.WhenAll(tasks.ToArray());
109-
110-
List<Task<List<GitPullRequestIteration>>> iterationTasks = new List<Task<List<GitPullRequestIteration>>>();
111-
112-
PullRequestRowViewModel[] pullRequests = null;
113-
if (this.queryInfo.Filter == PullRequestQueryFilter.None)
114-
{
115-
pullRequests = queryAsyncTask.Result.Select(r => CreateViewModel(r, projectContext)).ToArray();
116-
}
117-
else if (this.queryInfo.Filter == PullRequestQueryFilter.NeedsAction)
118-
{
119-
pullRequests = queryAsyncTask.Result.Select(r => CreateViewModel(r, projectContext)).Where(x => x.IsNeedsAction).ToArray();
120-
}
121-
else
100+
// Execute with token refresh in case the token has expired
101+
var pullRequests = await projectContext.ExecuteWithTokenRefreshAsync(async () =>
122102
{
123-
pullRequests = queryAsyncTask.Result.Select(r => CreateViewModel(r, projectContext)).ToArray();
124-
}
125-
126-
foreach (var pullRequest in pullRequests)
127-
{
128-
var asyncTask = projectContext.GitHttpClient.GetPullRequestIterationsAsync(
129-
pullRequest.Reference.Repository.Id,
130-
pullRequest.Reference.PullRequestId);
103+
var gitClient = projectContext.Connection.GetClient<Microsoft.TeamFoundation.SourceControl.WebApi.GitHttpClient>();
131104

132-
iterationTasks.Add(asyncTask);
133-
134-
pullRequest.Url = projectContext.HyperlinkFactory.GetPullRequestUrl(
135-
pullRequest.Reference.PullRequestId,
105+
var queryResult = await gitClient.GetPullRequestsByProjectAsync(
136106
query.ProjectName,
137-
pullRequest.Reference.Repository.Name);
138-
}
139-
140-
await Task.WhenAll(iterationTasks.ToArray());
141-
142-
int i = 0;
143-
foreach (var pullRequest in pullRequests)
144-
{
145-
pullRequest.Iterations = iterationTasks[i++].Result;
146-
}
107+
query.GitPullRequestSearchCriteria);
108+
109+
PullRequestRowViewModel[] prs = null;
110+
if (this.queryInfo.Filter == PullRequestQueryFilter.None)
111+
{
112+
prs = queryResult.Select(r => CreateViewModel(r, projectContext)).ToArray();
113+
}
114+
else if (this.queryInfo.Filter == PullRequestQueryFilter.NeedsAction)
115+
{
116+
prs = queryResult.Select(r => CreateViewModel(r, projectContext)).Where(x => x.IsNeedsAction).ToArray();
117+
}
118+
else
119+
{
120+
prs = queryResult.Select(r => CreateViewModel(r, projectContext)).ToArray();
121+
}
122+
123+
List<Task<List<GitPullRequestIteration>>> iterationTasks = new List<Task<List<GitPullRequestIteration>>>();
124+
125+
foreach (var pullRequest in prs)
126+
{
127+
var asyncTask = gitClient.GetPullRequestIterationsAsync(
128+
pullRequest.Reference.Repository.Id,
129+
pullRequest.Reference.PullRequestId);
130+
131+
iterationTasks.Add(asyncTask);
132+
133+
pullRequest.Url = projectContext.HyperlinkFactory.GetPullRequestUrl(
134+
pullRequest.Reference.PullRequestId,
135+
query.ProjectName,
136+
pullRequest.Reference.Repository.Name);
137+
}
138+
139+
await Task.WhenAll(iterationTasks.ToArray());
140+
141+
int i = 0;
142+
foreach (var pullRequest in prs)
143+
{
144+
pullRequest.Iterations = iterationTasks[i++].Result;
145+
}
146+
147+
return prs;
148+
});
147149

148150
OnQueryCompleted(projectContext, pullRequests, notificationScope);
149151
}

Source/TeamMate/ViewModels/QueryHierarchyItemViewModel.cs

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,6 @@ public QueryHierarchyItemViewModel(ProjectContext projectContext, QueryHierarchy
4040

4141
protected override async Task<IEnumerable<TreeItemViewModelBase>> LoadChildrenAsync()
4242
{
43-
// Get a fresh client from the connection to ensure we have the latest token
44-
var client = this.projectContext.Connection.GetClient<WorkItemTrackingHttpClient>();
4543
var project = this.projectContext.ProjectName;
4644

4745
IEnumerable<QueryHierarchyItem> children = null;
@@ -50,14 +48,22 @@ protected override async Task<IEnumerable<TreeItemViewModelBase>> LoadChildrenAs
5048
if (this.Item.IsFolder == true)
5149
{
5250
await ChaosMonkey.ChaosAsync(ChaosScenarios.LoadQueryFolder);
53-
var selfWithChildren = await client.GetQueryAsync(project, this.Item.Id.ToString(), depth: 1, expand: QueryExpand.Wiql);
54-
children = selfWithChildren.Children;
51+
children = await this.projectContext.ExecuteWithTokenRefreshAsync(async () =>
52+
{
53+
var client = this.projectContext.Connection.GetClient<WorkItemTrackingHttpClient>();
54+
var selfWithChildren = await client.GetQueryAsync(project, this.Item.Id.ToString(), depth: 1, expand: QueryExpand.Wiql);
55+
return selfWithChildren.Children;
56+
});
5557
}
5658
}
5759
else
5860
{
5961
// This is the root, find root query folders (TODO: consider depth: 1 for 1 less call)
60-
children = await client.GetQueriesAsync(project, expand: QueryExpand.Wiql);
62+
children = await this.projectContext.ExecuteWithTokenRefreshAsync(async () =>
63+
{
64+
var client = this.projectContext.Connection.GetClient<WorkItemTrackingHttpClient>();
65+
return await client.GetQueriesAsync(project, expand: QueryExpand.Wiql);
66+
});
6167
}
6268

6369
if (children == null)

Source/TeamMate/ViewModels/WorkItemQueryViewModel.cs

Lines changed: 10 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -125,25 +125,19 @@ private async Task DoRefreshAsync(NotificationScope notificationScope)
125125

126126
await ChaosMonkey.ChaosAsync(ChaosScenarios.WorkItemQueryExecution);
127127

128-
List<Task> tasks = new List<Task>();
129-
130-
var queryAsyncTask = projectContext.WorkItemTrackingClient.QueryAsync(query);
131-
tasks.Add(queryAsyncTask);
132-
133-
Task<QueryHierarchyItem> getQueryTask = null;
134-
if (query.QueryId != Guid.Empty)
128+
// Execute with token refresh in case the token has expired
129+
queryResult = await projectContext.ExecuteWithTokenRefreshAsync(async () =>
135130
{
136-
getQueryTask = projectContext.WorkItemTrackingClient.GetQueryAsync(query.ProjectName, query.QueryId.ToString(), QueryExpand.All);
137-
tasks.Add(getQueryTask);
138-
}
131+
var client = projectContext.Connection.GetClient<Microsoft.TeamFoundation.WorkItemTracking.WebApi.WorkItemTrackingHttpClient>();
132+
var result = await client.QueryAsync(query);
139133

140-
await Task.WhenAll(tasks.ToArray());
134+
if (query.QueryId != Guid.Empty)
135+
{
136+
result.QueryHierarchyItem = await client.GetQueryAsync(query.ProjectName, query.QueryId.ToString(), QueryExpand.All);
137+
}
141138

142-
queryResult = queryAsyncTask.Result;
143-
if (getQueryTask != null)
144-
{
145-
queryResult.QueryHierarchyItem = getQueryTask.Result;
146-
}
139+
return result;
140+
});
147141

148142
OnQueryCompleted(projectContext, queryResult, notificationScope);
149143
}

Source/TeamMate/ViewModels/WorkItemsPageViewModel.cs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -206,8 +206,12 @@ private async void EditTags()
206206
{
207207
using (this.StatusService.BusyIndicator())
208208
{
209-
var batchClient = this.SessionService.Session.ProjectContext.WorkItemTrackingBatchClient;
210-
var results = await batchClient.BatchUpdateWorkItemsAsync(updateRequests);
209+
var projectContext = this.SessionService.Session.ProjectContext;
210+
await projectContext.ExecuteWithTokenRefreshAsync(async () =>
211+
{
212+
var batchClient = projectContext.Connection.GetClient<Microsoft.Tools.TeamMate.TeamFoundation.WebApi.WorkItemTracking.WorkItemTrackingBatchHttpClient>();
213+
await batchClient.BatchUpdateWorkItemsAsync(updateRequests);
214+
});
211215

212216
// KLUDGE: To refresh the work item row view models in the easiest way.
213217
// Ideally, we can invaliadte each row, but the returned results might not return all the required fields that we

0 commit comments

Comments
 (0)